Aku Rouhe commited on
Commit
b57ac94
1 Parent(s): 0f77019
Files changed (1) hide show
  1. custom.py +1 -1
custom.py CHANGED
@@ -4,7 +4,7 @@ import speechbrain as sb
4
  class FeatureScaler(torch.nn.Module):
5
  def __init__(self, num_in, scale):
6
  super().__init__()
7
- self.scaler = torch.eye(num_in) * scale
8
 
9
  def forward(self, x):
10
  return x * self.scaler
 
4
  class FeatureScaler(torch.nn.Module):
5
  def __init__(self, num_in, scale):
6
  super().__init__()
7
+ self.scaler = torch.ones((num_in,))* scale
8
 
9
  def forward(self, x):
10
  return x * self.scaler