Spaces:
Runtime error
Runtime error
from pathlib import Path | |
import torch | |
from torch import nn | |
from einops import pack, unpack | |
import fairseq | |
from torchaudio.functional import resample | |
import logging | |
# Raise the root logger's threshold to ERROR as soon as this module is
# imported. The effect is process-wide: every logger that has no level of its
# own (including other libraries' loggers) inherits this threshold and drops
# INFO/WARNING records.
# NOTE(review): presumably meant to quiet fairseq's checkpoint-loading logs;
# consider scoping it to logging.getLogger("fairseq") instead — confirm.
logging.getLogger().setLevel(logging.ERROR)
def exists(val):
    """Report whether *val* holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val*, or the fallback *d* when ``exists(val)`` is false."""
    if not exists(val):
        return d
    return val
class CustomHubert(nn.Module):
    """
    Wrap a fairseq HuBERT checkpoint and return features from one of its layers.

    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own

    Note: the k-means quantisation step is disabled in this version, so
    ``forward`` returns the continuous features of ``output_layer`` rather than
    discrete codebook indices.
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz=16000,
        seq_len_multiple_of=None,
        output_layer=9
    ):
        """
        Load the HuBERT model from a fairseq checkpoint and put it in eval mode.

        Args:
            checkpoint_path: path to a fairseq HuBERT checkpoint file.
            target_sample_hz: sample rate the model is fed; ``forward`` resamples
                to this rate when given an ``input_sample_hz``.
            seq_len_multiple_of: stored on the instance; not used by this class.
            output_layer: layer index passed to the fairseq model's forward as
                ``output_layer``.

        Raises:
            AssertionError: if ``checkpoint_path`` does not exist (this check is
                skipped when Python runs with ``-O``).
        """
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        model_path = Path(checkpoint_path)
        assert model_path.exists(), f'path {checkpoint_path} does not exist'

        # Pass only the path: fairseq reads the checkpoint to CPU itself (and
        # applies its own state-dict upgrades). A separate torch.load here would
        # read the whole file a second time and keep an unused copy in memory,
        # because fairseq only iterates over the filenames it is given.
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
        self.model = model[0]
        self.model.eval()

    def groups(self):
        # NOTE(review): always 1; presumably a compatibility hook queried by
        # callers of this tokenizer — confirm.
        return 1

    def forward(
        self,
        wav_input,
        flatten=True,
        input_sample_hz=None
    ):
        """
        Extract layer-``output_layer`` HuBERT features from ``wav_input``.

        Args:
            wav_input: waveform tensor fed directly to the fairseq model
                (assumed (batch, samples) — TODO confirm against callers).
            flatten: if True, merge every leading dim of the features into one
                and return ``(N, dim)``; if False, restore the original leading
                dims of the model's ``'x'`` output.
            input_sample_hz: sample rate of ``wav_input``; when given, the audio
                is resampled to ``target_sample_hz`` before the model runs.

        Returns:
            A detached feature tensor (no autograd history) on ``wav_input``'s
            device.
        """
        device = wav_input.device

        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        # The result is detached below anyway, so don't build an autograd graph.
        with torch.no_grad():
            embed = self.model(
                wav_input,
                features_only=True,
                mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
                output_layer=self.output_layer
            )

        # Collapse all leading dims of the features into one: (..., d) -> (N, d).
        embed, packed_shape = pack([embed['x']], '* d')

        # With k-means disabled, the "codebook indices" are the features
        # themselves. detach() avoids a GPU -> host -> GPU copy through NumPy,
        # which also fails for dtypes NumPy lacks (e.g. bfloat16).
        codebook_indices = embed.detach().to(device)

        if flatten:
            return codebook_indices

        # The features still carry their trailing `d` axis, so the unpack
        # pattern must match the pack pattern ('* d'); a bare '*' only fits
        # 1-D index tensors and raises an EinopsError here.
        codebook_indices, = unpack(codebook_indices, packed_shape, '* d')
        return codebook_indices