File size: 1,124 Bytes
ad2cddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# custom_interface.py for CommonAccent English Accent Classifier
# Downloaded from: https://huggingface.co/Jzuluaga/accent-id-commonaccent_xlsr-en-english/blob/main/custom_interface.py
# This file is required by the SpeechBrain foreign_class interface.

import torch
from speechbrain.pretrained.interfaces import Pretrained

class CustomEncoderWav2vec2Classifier(Pretrained):
    MODULES_NEEDED = ["model", "mean_var_norm", "label_encoder"]
    HPARAMS_NEEDED = ["sample_rate"]

    def classify_file(self, path):
        signal, fs = self.load_audio(path)
        return self.classify_batch(signal, fs)

    def classify_batch(self, signal, fs):
        if fs != self.hparams.sample_rate:
            signal = self.resample(signal, fs, self.hparams.sample_rate)
        signal = self.modules.mean_var_norm(signal, torch.tensor([1]))
        embeddings = self.modules.model.encode_batch(signal)
        out_prob = self.modules.model.classify_batch(embeddings)
        score, index = torch.max(out_prob, dim=1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab