jimbozhang commited on
Commit
9578f22
1 Parent(s): 5e9bb10

Fix forward method for long audios.

Browse files
Files changed (1) hide show
  1. ced_model/modeling_ced.py +6 -12
ced_model/modeling_ced.py CHANGED
@@ -453,19 +453,13 @@ class CedModel(CedPreTrainedModel):
453
  splits = torch.stack(splits[:-1], dim=0)
454
  n_splits = len(splits)
455
  x = torch.flatten(splits, 0, 1) # spl b c f t-> (spl b) c f t
456
- x = self.forward_head(self.ced(x))
457
- x = torch.reshape(
458
- x, (n_splits, -1, self.outputdim)
459
- ) # (spl b) d -> spl b d, spl=n_splits
460
-
461
- if self.config.eval_avg == "mean":
462
- x = x.mean(0)
463
- elif self.config.eval_avg == "max":
464
- x = x.max(0)[0]
465
- else:
466
- raise ValueError(f"Unknown Eval average function ({self.eval_avg})")
467
  else:
468
- x = self.forward_features(x)
 
 
 
 
 
469
 
470
  return SequenceClassifierOutput(logits=x)
471
 
 
453
  splits = torch.stack(splits[:-1], dim=0)
454
  n_splits = len(splits)
455
  x = torch.flatten(splits, 0, 1) # spl b c f t-> (spl b) c f t
 
 
 
 
 
 
 
 
 
 
 
456
  else:
457
+ n_splits = 1
458
+
459
+ x = self.forward_features(x)
460
+ if n_splits > 1:
461
+ x = torch.flatten(x, 0, 1)
462
+ x = torch.unsqueeze(x, 0)
463
 
464
  return SequenceClassifierOutput(logits=x)
465