Edit model card


Unit-based HiFi-GAN Vocoder

Work in progress...

import pathlib as pl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
from speechbrain.inference.vocoders import UnitHIFIGAN
from speechbrain.lobes.models.huggingface_transformers import (
    hubert,
    wav2vec2,
    wavlm,
)
from speechbrain.lobes.models.huggingface_transformers.discrete_ssl import (
    DiscreteSSL,
)

# Lookup table from the `encoder_type` setting to the SpeechBrain
# huggingface_transformers wrapper class that loads that encoder family.
ENCODER_CLASSES = dict(
    HuBERT=hubert.HuBERT,
    Wav2Vec2=wav2vec2.Wav2Vec2,
    WavLM=wavlm.WavLM,
)

# Hugging Face Hub repo id passed to DiscreteSSL as `kmeans_repo_id`, plus the
# dataset name used to select which k-means model to load from it.
kmeans_folder = "poonehmousavi/SSL_Quantization"
kmeans_dataset = "LJSpeech"  # alternative: LibriSpeech-100-360-500
num_clusters = 1000
encoder_type = "HuBERT"  # one of [HuBERT, Wav2Vec2, WavLM]
encoder_source = "facebook/hubert-large-ll60k"
# Encoder hidden layers to quantize; each contributes one token stream.
layer = [3, 7, 12, 18, 23]
vocoder_source = (
    "chaanks/hifigan-unit-hubert-ll60k-l3-7-12-18-23-k1000-ljspeech-ljspeech"
)
save_path = pl.Path(".tmpdir")
# Fall back to CPU so the example still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sample_rate = 16000

# Example utterance inside the vocoder's repo; presumably fetched from the Hub
# by `vocoder.load_audio` — confirm. Derived from `vocoder_source` so the two
# repo ids cannot drift apart.
wav = f"{vocoder_source}/test.wav"

# Self-supervised encoder of the configured family. It is loaded with
# freeze=True / freeze_feature_extractor=True, and with output_all_hiddens=True
# so that specific layers can be selected for quantization later.
encoder_class = ENCODER_CLASSES[encoder_type]
encoder = encoder_class(
    source=encoder_source,
    save_path=(save_path / "encoder").as_posix(),
    freeze=True,
    freeze_feature_extractor=True,
    output_norm=False,
    output_all_hiddens=True,
    apply_spec_augment=False,
)
encoder = encoder.to(device)

# Quantizer that turns the encoder's hidden states into k-means cluster ids,
# using the k-means models from `kmeans_folder` for `kmeans_dataset`.
discrete_encoder = DiscreteSSL(
    ssl_model=encoder,
    kmeans_repo_id=kmeans_folder,
    kmeans_dataset=kmeans_dataset,
    num_clusters=num_clusters,
    save_path=(save_path / "discrete_encoder").as_posix(),
)

# Unit-based HiFi-GAN vocoder that synthesizes audio from the discrete units.
vocoder = UnitHIFIGAN.from_hparams(
    source=vocoder_source,
    savedir=(save_path / "vocoder").as_posix(),
    run_opts=dict(device=str(device)),
)

# Load the example utterance and add a leading batch dimension before moving
# it to the model device (assumes load_audio returns an unbatched signal — confirm).
audio = vocoder.load_audio(wav)
audio = audio.unsqueeze(0).to(device)

# Quantize every selected layer with no deduplication and no BPE merging.
# Only discrete tokens are produced, so autograd bookkeeping is unnecessary.
num_layer = len(layer)
deduplicates = [False] * num_layer
bpe_tokenizers = [None] * num_layer
with torch.no_grad():
    tokens, _, _ = discrete_encoder(
        audio,
        SSL_layers=layer,
        deduplicates=deduplicates,
        bpe_tokenizers=bpe_tokenizers,
    )
# Drop the batch dimension; the result is presumably (time, num_layer) since
# it broadcasts against the per-layer offsets below — confirm.
tokens = tokens.cpu().squeeze(0)

# Shift layer i's cluster ids by i * num_clusters so every layer occupies its
# own id range in one shared vocabulary (assumes raw ids lie in
# [0, num_clusters) and that the vocoder was trained on this layout — confirm).
offsets = torch.arange(num_layer) * num_clusters
tokens = tokens + offsets

waveform = vocoder.decode_unit(tokens)
torchaudio.save("pred.wav", waveform.cpu(), sample_rate=sample_rate)
Downloads last month
1
Inference API (serverless) has been turned off for this model.