File size: 793 Bytes
8b70882
 
 
fa827db
 
 
b57ac94
fa827db
0f77019
fa827db
 
85046ed
8b70882
 
 
 
 
 
fa827db
 
8b70882
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import speechbrain as sb

class FeatureScaler(torch.nn.Module):
    """Multiply every feature dimension by a fixed scale factor.

    Arguments
    ---------
    num_in : int
        Size of the feature dimension the scale vector is built for; it is
        broadcast against the last dimension of the input.
    scale : float
        Factor applied to every feature dimension.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so that
        # ``.to(device)`` / ``.cuda()`` / ``.double()`` on the module move and
        # cast it together with the inputs. ``persistent=False`` keeps it out
        # of ``state_dict`` so checkpoint contents are unchanged.
        self.register_buffer(
            "scaler", torch.ones((num_in,)) * scale, persistent=False
        )

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the scale vector.

        The ``(num_in,)`` scale vector broadcasts against the last
        dimension of ``x``.
        """
        return x * self.scaler

class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface that turns audio into normalized, scaled features.

    Pipeline: ``hparams.feature_extractor`` -> ``mods.normalizer`` ->
    ``mods.feature_scaler``.
    """

    # Both modules are used by ``feats_from_audio``; declaring them here lets
    # the base class report a missing entry up front instead of failing with
    # an attribute error on first use.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute features for a batch of waveforms.

        Arguments
        ---------
        audio : torch.Tensor
            Batched waveforms; assumed ``(batch, time)`` — TODO confirm
            against the configured ``feature_extractor``.
        lengths : torch.Tensor, optional
            One length per batch item, forwarded to ``normalizer``;
            presumably relative lengths in ``[0, 1]`` (SpeechBrain
            convention). Defaults to 1.0 for every item.

        Returns
        -------
        torch.Tensor
            Extracted features after normalization and scaling.
        """
        feats = self.hparams.feature_extractor(audio)
        if lengths is None:
            # Built per call rather than as a shared default-argument tensor:
            # one entry per batch item of ``feats`` (the old default had a
            # single entry whatever the batch size), on the same device as
            # the features the normalizer receives.
            lengths = torch.ones(feats.shape[0], device=feats.device)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load one audio file and return its features.

        A batch dimension is added before feature extraction and removed
        from the result.
        """
        audio = self.load_audio(path)
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)