|
|
|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
from ..interfaces import UpstreamBase |
|
|
from .audio import waveform_to_examples |
|
|
from .vggish import VGGish |
|
|
|
|
|
|
|
|
class UpstreamExpert(UpstreamBase):
    """Upstream wrapper that runs a ``VGGish`` model over a list of waveforms."""

    def __init__(self, ckpt, **kwargs):
        """Initialise the base class and build the wrapped model.

        Args:
            ckpt: Passed as the first positional argument to ``VGGish``.
            **kwargs: Forwarded unchanged to both the base-class
                initialiser and ``VGGish``.
        """
        super().__init__(**kwargs)
        self.model = VGGish(ckpt, **kwargs)

    def get_downsample_rates(self, key: str) -> int:
        """Return the downsample rate; the value is 16000 for every ``key``."""
        return 16000

    def forward(self, wavs):
        """Extract features for every waveform and pad them into one batch.

        Args:
            wavs: Indexable, iterable collection of tensors. The device of
                the first element is used for every model call.
                NOTE(review): presumably 1-D waveform tensors -- confirm
                against ``waveform_to_examples``.

        Returns:
            A dict with ``"last_hidden_state"`` holding the padded batch
            (batch dimension first) and ``"hidden_states"`` holding a
            one-element list containing that same tensor.
        """
        # Resolved up front, before any waveform is processed.
        target_device = wavs[0].device

        def embed(wav):
            # The example extractor is fed a NumPy array, so the tensor is
            # detached and copied to host memory first.
            examples = waveform_to_examples(wav.detach().cpu().numpy())
            embedding = self.model(examples.to(target_device))
            # A 1-D model output gets a leading axis so every entry handed
            # to pad_sequence has a sequence dimension.
            return embedding.unsqueeze(0) if embedding.dim() == 1 else embedding

        padded = pad_sequence([embed(wav) for wav in wavs], batch_first=True)

        return {
            "last_hidden_state": padded,
            "hidden_states": [padded],
        }
|
|
|