Spaces:
Runtime error
Runtime error
mattricesound
commited on
Commit
•
0e3a05d
1
Parent(s):
0fbacb2
Add target cropping if outputs are different length
Browse files- remfx/models.py +7 -20
- remfx/tcn.py +0 -2
remfx/models.py
CHANGED
@@ -13,6 +13,7 @@ from remfx.utils import FADLoss, spectrogram
|
|
13 |
from remfx.dptnet import DPTNet_base
|
14 |
from remfx.dcunet import RefineSpectrogramUnet
|
15 |
from remfx.tcn import TCN
|
|
|
16 |
|
17 |
|
18 |
class RemFX(pl.LightningModule):
|
@@ -223,21 +224,14 @@ class DCUNetModel(nn.Module):
|
|
223 |
def forward(self, batch):
|
224 |
x, target = batch
|
225 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
226 |
-
#
|
227 |
-
if output.shape[-1]
|
228 |
-
|
229 |
-
elif output.shape[-1] < target.shape[-1]:
|
230 |
-
output = F.pad(output, (0, target.shape[-1] - output.shape[-1]))
|
231 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
232 |
return loss, output
|
233 |
|
234 |
def sample(self, x: Tensor) -> Tensor:
|
235 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
236 |
-
# Pad or crop to match target
|
237 |
-
if output.shape[-1] > x.shape[-1]:
|
238 |
-
output = output[:, : x.shape[-1]]
|
239 |
-
elif output.shape[-1] < x.shape[-1]:
|
240 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
241 |
return output
|
242 |
|
243 |
|
@@ -253,21 +247,14 @@ class TCNModel(nn.Module):
|
|
253 |
def forward(self, batch):
|
254 |
x, target = batch
|
255 |
output = self.model(x) # B x 1 x T
|
256 |
-
#
|
257 |
-
if output.shape[-1]
|
258 |
-
|
259 |
-
elif output.shape[-1] < x.shape[-1]:
|
260 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
261 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
262 |
return loss, output
|
263 |
|
264 |
def sample(self, x: Tensor) -> Tensor:
|
265 |
output = self.model(x) # B x 1 x T
|
266 |
-
# Pad or crop to match target
|
267 |
-
if output.shape[-1] > x.shape[-1]:
|
268 |
-
output = output[:, : x.shape[-1]]
|
269 |
-
elif output.shape[-1] < x.shape[-1]:
|
270 |
-
output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
|
271 |
return output
|
272 |
|
273 |
|
|
|
13 |
from remfx.dptnet import DPTNet_base
|
14 |
from remfx.dcunet import RefineSpectrogramUnet
|
15 |
from remfx.tcn import TCN
|
16 |
+
from remfx.utils import causal_crop
|
17 |
|
18 |
|
19 |
class RemFX(pl.LightningModule):
|
|
|
224 |
def forward(self, batch):
|
225 |
x, target = batch
|
226 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
227 |
+
# Crop target to match output
|
228 |
+
if output.shape[-1] < target.shape[-1]:
|
229 |
+
target = causal_crop(target, output.shape[-1])
|
|
|
|
|
230 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
231 |
return loss, output
|
232 |
|
233 |
def sample(self, x: Tensor) -> Tensor:
|
234 |
output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
|
|
|
|
|
|
|
|
|
|
|
235 |
return output
|
236 |
|
237 |
|
|
|
247 |
def forward(self, batch):
|
248 |
x, target = batch
|
249 |
output = self.model(x) # B x 1 x T
|
250 |
+
# Crop target to match output
|
251 |
+
if output.shape[-1] < target.shape[-1]:
|
252 |
+
target = causal_crop(target, output.shape[-1])
|
|
|
|
|
253 |
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
254 |
return loss, output
|
255 |
|
256 |
def sample(self, x: Tensor) -> Tensor:
|
257 |
output = self.model(x) # B x 1 x T
|
|
|
|
|
|
|
|
|
|
|
258 |
return output
|
259 |
|
260 |
|
remfx/tcn.py
CHANGED
@@ -25,8 +25,6 @@ class TCNBlock(nn.Module):
|
|
25 |
self.stride = stride
|
26 |
|
27 |
self.crop_fn = crop_fn
|
28 |
-
# Assumes stride of 1
|
29 |
-
padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
|
30 |
self.conv1 = nn.Conv1d(
|
31 |
in_ch,
|
32 |
out_ch,
|
|
|
25 |
self.stride = stride
|
26 |
|
27 |
self.crop_fn = crop_fn
|
|
|
|
|
28 |
self.conv1 = nn.Conv1d(
|
29 |
in_ch,
|
30 |
out_ch,
|