Spaces:
Runtime error
Runtime error
File size: 2,678 Bytes
bb5a96d 4596c24 bb5a96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.lcnn import LCNN
from src import frontends
from src.commons import WHISPER_MODEL_WEIGHTS_PATH
class WhisperLCNN(LCNN):
    """LCNN classifier that consumes features produced by a Whisper model.

    The Whisper weights are read from ``WHISPER_MODEL_WEIGHTS_PATH`` and the
    resulting model is stored on the instance as ``whisper_model``.
    """

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        """Build the LCNN backbone and attach the pretrained Whisper model.

        Args:
            input_channels: Forwarded to ``LCNN.__init__``.
            freeze_encoder: When truthy, gradients are disabled for every
                parameter of the loaded Whisper model.
            **kwargs: Forwarded to ``LCNN.__init__``; must contain ``device``.

        Raises:
            KeyError: If ``device`` is not present in ``kwargs``.
        """
        super().__init__(input_channels=input_channels, **kwargs)
        self.device = kwargs["device"]

        # The checkpoint is always deserialized on CPU; the model itself is
        # moved to the target device before the weights are copied in.
        state = torch.load(
            WHISPER_MODEL_WEIGHTS_PATH,
            map_location=torch.device("cpu"),
        )
        whisper = Whisper(ModelDimensions(**vars(state["dims"])))
        whisper = whisper.to(self.device)
        whisper.load_state_dict(state["model_state_dict"])
        self.whisper_model = whisper

        if freeze_encoder:
            for weight in self.whisper_model.parameters():
                weight.requires_grad = False

    def compute_whisper_features(self, x):
        """Turn a batch of waveforms into a 4-D feature tensor for the LCNN.

        Each item of ``x`` is converted with ``log_mel_spectrogram``, the
        stacked spectrograms are passed through ``whisper_model``, and the
        3-D result is rearranged: its last two axes are swapped, a singleton
        axis is inserted at position 1, and the last axis is tiled twice.

        Args:
            x: Iterable of per-sample inputs accepted by
                ``log_mel_spectrogram`` (assumed to be a batch of raw audio
                waveforms -- TODO confirm).

        Returns:
            A 4-D tensor whose axis 1 has size 1.
        """
        mel_batch = torch.stack([log_mel_spectrogram(wave) for wave in x])
        encoded = self.whisper_model(mel_batch)
        # NOTE(review): presumably `encoded` is (bs, frames, features) and the
        # tiling below stretches the last axis to 3000 for 30 s input, as the
        # original inline comment suggested -- confirm against Whisper dims.
        return encoded.permute(0, 2, 1).unsqueeze(1).repeat(1, 1, 1, 2)

    def forward(self, x):
        """Return the LCNN embedding computed from Whisper features of ``x``.

        The input is assumed to already have the expected length (i.e. 30 s);
        no padding or trimming is performed here.
        """
        return self._compute_embedding(self.compute_whisper_features(x))
class WhisperMultiFrontLCNN(WhisperLCNN):
    """Whisper-based LCNN that additionally uses a second, classic frontend.

    The output of the extra frontend is concatenated with the Whisper
    features along axis 1 before the embedding is computed.
    """

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        """Initialise the Whisper LCNN and resolve the additional frontend.

        Args:
            input_channels: Forwarded to ``WhisperLCNN.__init__``.
            freeze_encoder: Forwarded to ``WhisperLCNN.__init__``.
            **kwargs: Forwarded to ``WhisperLCNN.__init__``; must contain
                ``frontend_algorithm``, which is handed to
                ``frontends.get_frontend``.

        Raises:
            KeyError: If ``frontend_algorithm`` is not present in ``kwargs``.
        """
        super().__init__(
            input_channels=input_channels,
            freeze_encoder=freeze_encoder,
            **kwargs,
        )
        algorithm = kwargs["frontend_algorithm"]
        self.frontend = frontends.get_frontend(algorithm)
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        """Return the embedding of Whisper and frontend features of ``x``."""
        # The frontend is evaluated on the raw input first, then the Whisper
        # branch; the Whisper features come first in the concatenation.
        frontend_feats = self.frontend(x)
        whisper_feats = self.compute_whisper_features(x)
        combined = torch.cat((whisper_feats, frontend_feats), dim=1)
        return self._compute_embedding(combined)
if __name__ == "__main__":
    # Smoke test: build both model variants on CPU and print the shape of
    # their outputs for a random batch of two 30 s clips sampled at 16 kHz.
    import numpy as np

    device = "cpu"

    input_channels = 1
    classifier = WhisperLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
    )

    # Second variant is built with two input channels and the "lfcc" frontend.
    input_channels = 2
    classifier_2 = WhisperMultiFrontLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
        frontend_algorithm="lfcc",
    )

    x = torch.from_numpy(
        np.random.rand(2, 30 * 16_000).astype(np.float32)
    )

    for model in (classifier, classifier_2):
        out = model(x)
        print(out.shape)
|