Vocoder with HiFi-GAN Unit

Work in progress...

import pathlib as pl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
from speechbrain.inference.vocoders import UnitHIFIGAN
from speechbrain.lobes.models.huggingface_transformers import (
    hubert,
    wav2vec2,
    wavlm,
)
from speechbrain.lobes.models.huggingface_transformers.discrete_ssl import (
    DiscreteSSL,
)

# Maps the `encoder_type` setting below to the SpeechBrain SSL wrapper class.
ENCODER_CLASSES = {
    "HuBERT": hubert.HuBERT,
    "Wav2Vec2": wav2vec2.Wav2Vec2,
    "WavLM": wavlm.WavLM,
}

# Hugging Face repo holding the pre-trained k-means quantizers, and the dataset
# they were fitted on. NOTE(review): "LibriSpeech-100-360-500" appears to be
# the alternative dataset name — confirm against the repo.
kmeans_folder = "poonehmousavi/SSL_Quantization"
kmeans_dataset = "LJSpeech"  # LibriSpeech-100-360-500
num_clusters = 1000
encoder_type = "HuBERT"  # one of [HuBERT, Wav2Vec2, WavLM]
encoder_source = "facebook/hubert-large-ll60k"
# SSL hidden layers to tokenize. The vocoder repo id below encodes
# "l3-7-12-18-23" and "k1000"; presumably `layer` and `num_clusters` must match
# those values for the vocoder to decode correctly — NOTE(review): confirm.
layer = [3, 7, 12, 18, 23]
vocoder_source = (
    "chaanks/hifigan-unit-hubert-ll60k-l3-7-12-18-23-k1000-ljspeech-ljspeech"
)
# Local cache root; each model gets its own sub-folder beneath it.
save_path = pl.Path(".tmpdir")
# Prefer the GPU, but fall back to CPU so the example also runs on machines
# without CUDA instead of failing when models are moved to the device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate (Hz) written into the output file by torchaudio.save.
sample_rate = 16000

# Example clip stored in the vocoder repo. It is derived from `vocoder_source`
# so the repo id is spelled out only once.
wav = f"{vocoder_source}/test.wav"

# --- Model loading -----------------------------------------------------------
# Every component caches its files in its own sub-folder of `save_path`.
encoder_dir = (save_path / "encoder").as_posix()
discrete_dir = (save_path / "discrete_encoder").as_posix()
vocoder_dir = (save_path / "vocoder").as_posix()

# SSL backbone selected by `encoder_type`. It is frozen, uses no SpecAugment,
# and exposes every hidden state (output_all_hiddens=True) so that specific
# layers can be tokenized downstream.
encoder_class = ENCODER_CLASSES[encoder_type]
encoder_kwargs = dict(
    source=encoder_source,
    save_path=encoder_dir,
    output_norm=False,
    freeze=True,
    freeze_feature_extractor=True,
    apply_spec_augment=False,
    output_all_hiddens=True,
)
encoder = encoder_class(**encoder_kwargs).to(device)

# Discrete tokenizer wrapping the encoder. It presumably loads the k-means
# quantizers for `kmeans_dataset` / `num_clusters` from `kmeans_repo_id`
# (NOTE(review): confirm against DiscreteSSL).
discrete_encoder = DiscreteSSL(
    save_path=discrete_dir,
    ssl_model=encoder,
    kmeans_dataset=kmeans_dataset,
    kmeans_repo_id=kmeans_folder,
    num_clusters=num_clusters,
)

# Pretrained unit-to-waveform vocoder, placed on the same device.
vocoder = UnitHIFIGAN.from_hparams(
    source=vocoder_source,
    run_opts={"device": str(device)},
    savedir=vocoder_dir,
)

# --- Inference ---------------------------------------------------------------
# Load the example clip through the vocoder's audio loader, add a leading
# (presumably batch) axis, and move it to `device`.
audio = vocoder.load_audio(wav).unsqueeze(0).to(device)

# One setting per requested layer: keep repeated units (no deduplication) and
# apply no BPE tokenizer.
num_layer = len(layer)
deduplicates = [False] * num_layer
bpe_tokenizers = [None] * num_layer

# The encoder returns three values; only the first (the unit ids) is used.
tokens, _, _ = discrete_encoder(
    audio,
    SSL_layers=layer,
    deduplicates=deduplicates,
    bpe_tokenizers=bpe_tokenizers,
)

# Remove the leading axis and add a per-layer offset (k * num_clusters for the
# k-th layer) along the last axis so that each layer gets a distinct id range.
# Assumes tokens are (time, num_layer) after the squeeze — NOTE(review): confirm.
offsets = torch.arange(num_layer) * num_clusters
tokens = tokens.cpu().squeeze(0) + offsets

# Synthesize the waveform and write it out from CPU memory.
waveform = vocoder.decode_unit(tokens)
torchaudio.save("pred.wav", waveform.cpu(), sample_rate=sample_rate)
Downloads last month: 38
Inference API
Inference API (serverless) has been turned off for this model.