import torch

from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.lcnn import LCNN
from src import frontends
from src.commons import WHISPER_MODEL_WEIGHTS_PATH


class WhisperLCNN(LCNN):
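    """LCNN classifier operating on embeddings from a pretrained (and
    optionally frozen) Whisper encoder.
    """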

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, **kwargs)

        self.device = kwargs['device']
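
        # Load the pretrained Whisper checkpoint (expected at
        # WHISPER_MODEL_WEIGHTS_PATH; presumably prepared during repository
        # setup) and rebuild the model on the target device.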
        checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'))
        dims = ModelDimensions(**checkpoint["dims"].__dict__)
        model = Whisper(dims)
        model = model.to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.whisper_model = model
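        # Optionally freeze all Whisper parameters so that only the LCNN
        # classifier on top is trained.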
        if freeze_encoder:
            for param in self.whisper_model.parameters():
                param.requires_grad = False

    def compute_whisper_features(self, x):
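        """Turn a batch of raw waveforms into Whisper encoder embeddings
        shaped as single-channel 2D feature maps for the LCNN.
        """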
        specs = []
        for sample in x:
            specs.append(log_mel_spectrogram(sample))
        x = torch.stack(specs)
        x = self.whisper_model(x)

        # Encoder output: (bs, n_frames, n_state); reshape it into a
        # single-channel 2D feature map for the LCNN.
        x = x.permute(0, 2, 1)  # -> (bs, n_state, n_frames); n_state lines up with the frontend's 3 x n_lfcc feature height
        x = x.unsqueeze(1)  # -> (bs, 1, n_state, n_frames)
        # The encoder halves the 3000 mel frames, so tiling the time axis
        # twice restores a width of 3000.
        x = x.repeat(
            (1, 1, 1, 2)
        )  # -> (bs, 1, n_state, 2 * n_frames = 3000)
        return x

    def forward(self, x):
        # Inputs are assumed to be fixed-length 30 s waveforms (Whisper's expected context).
        x = self.compute_whisper_features(x)
        out = self._compute_embedding(x)
        return out


class WhisperMultiFrontLCNN(WhisperLCNN):
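    """WhisperLCNN variant that concatenates the Whisper embeddings with a
    classical frontend (e.g. LFCC) along the channel axis.
    """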

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, freeze_encoder=freeze_encoder, **kwargs)

        self.frontend = frontends.get_frontend(kwargs['frontend_algorithm'])
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        # Compute the classical frontend features and the Whisper embeddings
        # from the same raw waveform.
        frontend_x = self.frontend(x)
        x = self.compute_whisper_features(x)

        # Concatenate the two single-channel feature maps along the channel
        # axis (hence input_channels=2 for this model).
        x = torch.cat([x, frontend_x], 1)
        out = self._compute_embedding(x)
        return out


if __name__ == "__main__":
    import numpy as np

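    # Minimal smoke test: instantiate both models and run a batch of random
    # 30 s, 16 kHz waveforms through them.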
    input_channels = 1
    device = "cpu"
    classifier = WhisperLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
    )

    input_channels = 2
    classifier_2 = WhisperMultiFrontLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
        frontend_algorithm="lfcc"
    )
    x = np.random.rand(2, 30 * 16_000).astype(np.float32)  # batch of 2 random 30 s clips at 16 kHz
    x = torch.from_numpy(x)

    out = classifier(x)
    print(out.shape)

    out = classifier_2(x)
    print(out.shape)