|
from pathlib import Path |
|
|
|
import torch |
|
from torch import nn |
|
from einops import pack, unpack |
|
|
|
import fairseq |
|
|
|
from torchaudio.functional import resample |
|
|
|
import logging |
|
# Raise the process-wide root logger threshold to ERROR at import time so that
# INFO/WARNING records (e.g. from the libraries imported above) are suppressed.
# logging.getLogger() with no name returns the root logger.
logging.getLogger().setLevel(logging.ERROR)
|
|
|
|
|
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
|
def default(val, d):
    """Return *val*, or the fallback *d* when *val* is ``None``.

    Only ``None`` triggers the fallback; other falsy values are returned as-is.
    """
    if val is None:
        return d
    return val
|
|
|
|
|
class CustomHubert(nn.Module):
    """
    Wrapper around a fairseq HuBERT checkpoint that returns the hidden features
    of one transformer layer (no quantization is applied in this class).

    Checkpoints can be downloaded at
    https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own.
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz=16000,
        seq_len_multiple_of=None,
        output_layer=9
    ):
        """
        Args:
            checkpoint_path: path (str or Path) to a fairseq HuBERT checkpoint.
            target_sample_hz: sample rate the model expects; ``forward``
                resamples to it when ``input_sample_hz`` is given.
            seq_len_multiple_of: if set, ``forward`` truncates the waveform's
                last dimension down to the nearest multiple of this value.
            output_layer: index of the transformer layer whose output
                ``forward`` returns (passed to the fairseq model).
        """
        super().__init__()

        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        model_path = Path(checkpoint_path)

        assert model_path.exists(), f'path {checkpoint_path} does not exist'

        # The fairseq loader takes an iterable of filenames and reads each file
        # itself. A separate torch.load here would only be a wasted second copy
        # of the checkpoint in memory (iterating a {path: ckpt} dict yields just
        # the key), and can fail on its own, e.g. for GPU-saved checkpoints on
        # CPU-only hosts. str() also lets callers pass a pathlib.Path.
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [str(model_path)]
        )

        self.model = model[0]
        self.model.eval()

    @property
    def groups(self):
        # Always 1; presumably the number of codebook groups callers expect
        # from a quantizer-like interface -- TODO confirm against callers.
        return 1

    @torch.no_grad()
    def forward(
        self,
        wav_input,
        flatten=True,
        input_sample_hz=None
    ):
        """
        Extract features from layer ``self.output_layer`` of the HuBERT model.

        Args:
            wav_input: waveform tensor fed to the fairseq model (presumably
                (batch, samples) -- TODO confirm against callers).
            flatten: if True, merge all leading dims of the model output into
                one and return shape (N, d); otherwise restore the original
                leading dims of the model output.
            input_sample_hz: sample rate of ``wav_input``; when given, the
                audio is first resampled to ``self.target_sample_hz``.

        Returns:
            Feature tensor on the same device as ``wav_input``.
        """
        device = wav_input.device

        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            # Honour the constructor option (previously stored but ignored):
            # drop trailing samples so the length is a multiple of it.
            mult = self.seq_len_multiple_of
            usable_len = (wav_input.shape[-1] // mult) * mult
            wav_input = wav_input[..., :usable_len]

        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,
            output_layer=self.output_layer
        )

        # Merge every leading dim of the layer output into one: (..., d) -> (N, d).
        features, packed_shape = pack([embed['x']], '* d')

        # Keep the tensor in torch: the former numpy round-trip forced a
        # device -> host -> device copy and cannot represent bfloat16.
        features = features.detach().to(device)

        if flatten:
            return features

        # `features` is 2-D (N, d), so the unpack pattern must keep the feature
        # axis; einops rejects the bare '*' pattern for a 2-D tensor.
        features, = unpack(features, packed_shape, '* d')

        return features
|
|