jimbozhang commited on
Commit
548d485
1 Parent(s): 9578f22

Update ced_model/modeling_ced.py

Browse files
Files changed (1) hide show
  1. ced_model/modeling_ced.py +1 -3
ced_model/modeling_ced.py CHANGED
@@ -457,9 +457,7 @@ class CedModel(CedPreTrainedModel):
457
  n_splits = 1
458
 
459
  x = self.forward_features(x)
460
- if n_splits > 1:
461
- x = torch.flatten(x, 0, 1)
462
- x = torch.unsqueeze(x, 0)
463
 
464
  return SequenceClassifierOutput(logits=x)
465
 
 
457
  n_splits = 1
458
 
459
  x = self.forward_features(x)
460
+ x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
 
 
461
 
462
  return SequenceClassifierOutput(logits=x)
463