lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.1 kB
import importlib
from typing import List
import torch
import torch.nn as nn
import torchaudio
from torch.nn.utils.rnn import pad_sequence
# Sample rate (Hz) of the waveforms fed to the resampler; also used to convert
# the embedding stride from seconds to samples.
SAMPLE_RATE = 16000


class UpstreamExpert(nn.Module):
    """Wrap a ``hear21passt`` model module behind a ``forward`` over waveforms.

    The wrapped module is imported dynamically from the ``.hear21passt``
    subpackage and must expose ``load_model`` and
    ``get_timestamp_embeddings``.
    """

    def __init__(
        self,
        name: str,
        refresh=False,
        window_secs: float = 0.16,
        stride_secs: float = 0.05,
    ):
        """Import the model module, load the model and build the resampler.

        Args:
            name: Name of the submodule of ``.hear21passt`` to import.
            refresh: Accepted for interface compatibility; not used here.
            window_secs: Window length in seconds; passed to ``load_model``
                in milliseconds as ``timestamp_window``.
            stride_secs: Hop length in seconds; passed to ``load_model`` in
                milliseconds as ``timestamp_hop``.
        """
        super().__init__()
        # Inputs are resampled from SAMPLE_RATE to 32 kHz before embedding
        # (presumably the rate the wrapped model expects -- TODO confirm).
        self.resampler = torchaudio.transforms.Resample(SAMPLE_RATE, 32000)
        self.module = importlib.import_module(f".hear21passt.{name}", __package__)
        self.model = self.module.load_model(
            timestamp_window=window_secs * 1000,
            timestamp_hop=stride_secs * 1000,
        )
        self.stride_secs = stride_secs

    def get_downsample_rates(self, key=None) -> int:
        """Return the stride expressed in samples at ``SAMPLE_RATE``.

        Args:
            key: Ignored; the same value is returned for every key.

        Returns:
            ``stride_secs * SAMPLE_RATE`` truncated to an integer.
        """
        # A stride that is a whole number of samples can come out a hair below
        # the integer in binary floating point (0.0625625 * 16000 evaluates to
        # just under 1001), so bare int() would drop a sample. Rounding to six
        # decimals first removes that noise while still truncating strides
        # that are genuinely fractional.
        return int(round(self.stride_secs * SAMPLE_RATE, 6))

    def forward(self, wavs: List[torch.Tensor]) -> dict:
        """Embed a batch of waveforms.

        Args:
            wavs: Waveform tensors; they are padded to a common length, so
                they may differ in length (assumes each is a 1-D signal
                sampled at ``SAMPLE_RATE`` -- TODO confirm).

        Returns:
            A dict whose ``"hidden_states"`` entry is a one-element list
            holding the embeddings returned by ``get_timestamp_embeddings``.
        """
        batch = pad_sequence(wavs, batch_first=True)
        batch = self.resampler(batch)
        # The timestamps returned alongside the embeddings are not needed.
        embs, _timestamps = self.module.get_timestamp_embeddings(batch, self.model)
        return {
            "hidden_states": [embs],
        }