Spaces:
Sleeping
Sleeping
Commit
·
aecaaea
1
Parent(s):
f8fea2a
Fix causal cropping for input metrics
Browse files- remfx/models.py +4 -2
remfx/models.py
CHANGED
@@ -188,8 +188,9 @@ class RemFX(pl.LightningModule):
|
|
188 |
|
189 |
loss, output = self.model((x, y))
|
190 |
# Crop target to match output
|
|
|
191 |
if output.shape[-1] < y.shape[-1]:
|
192 |
-
|
193 |
self.log(f"{mode}_loss", loss)
|
194 |
# Metric logging
|
195 |
with torch.no_grad():
|
@@ -204,13 +205,14 @@ class RemFX(pl.LightningModule):
|
|
204 |
continue
|
205 |
self.log(
|
206 |
f"{mode}_{metric}",
|
207 |
-
negate * self.metrics[metric](output,
|
208 |
on_step=False,
|
209 |
on_epoch=True,
|
210 |
logger=True,
|
211 |
prog_bar=True,
|
212 |
sync_dist=True,
|
213 |
)
|
|
|
214 |
self.log(
|
215 |
f"Input_{metric}",
|
216 |
negate * self.metrics[metric](x, y),
|
|
|
188 |
|
189 |
loss, output = self.model((x, y))
|
190 |
# Crop target to match output
|
191 |
+
target = y
|
192 |
if output.shape[-1] < y.shape[-1]:
|
193 |
+
target = causal_crop(y, output.shape[-1])
|
194 |
self.log(f"{mode}_loss", loss)
|
195 |
# Metric logging
|
196 |
with torch.no_grad():
|
|
|
205 |
continue
|
206 |
self.log(
|
207 |
f"{mode}_{metric}",
|
208 |
+
negate * self.metrics[metric](output, target),
|
209 |
on_step=False,
|
210 |
on_epoch=True,
|
211 |
logger=True,
|
212 |
prog_bar=True,
|
213 |
sync_dist=True,
|
214 |
)
|
215 |
+
|
216 |
self.log(
|
217 |
f"Input_{metric}",
|
218 |
negate * self.metrics[metric](x, y),
|