import os
import re
import yaml
import torch
import torchaudio

import audioldm2.latent_diffusion.modules.phoneme_encoder.text as text
from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion
from audioldm2.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audioldm2.utils import default_audioldm_config, download_checkpoint

# CACHE_DIR = os.getenv(
#     "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
# )
def seed_everything(seed):
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True  # note: benchmark=True favors speed over strict run-to-run determinism
def text2phoneme(data):
    # Strip HTML-like tags, then run the VITS text cleaner to phonemize the transcription.
    return text._clean_text(re.sub(r'<.*?>', '', data), ["english_cleaners2"])
def text_to_filename(text):
    # Turn the prompt into a filesystem-safe filename.
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
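# Illustrative example of the mapping above (not part of the upstream API):
#   text_to_filename('A dog barking in the "rain"') -> 'A_dog_barking_in_the__rain_'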
def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    # Dataset-level mean/std used to normalize the kaldi fbank features.
    norm_mean = -4.2677393
    norm_std = 4.5689974

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec.size(0)

    # Cut or zero-pad along the time axis so the fbank matches the target length.
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
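# Illustrative sketch, not called anywhere in this module: with the defaults above,
# a 10 s mono waveform at 16 kHz (160000 samples) yields roughly 1000 fbank frames,
# which are then padded/cropped to the time length of the reference spectrogram.
def _example_extract_kaldi_fbank():
    dummy_waveform = torch.randn(1, 160000)       # [channels, samples], already at 16 kHz
    dummy_log_mel_spec = torch.zeros(1024, 64)    # only its first dimension (1024) is used
    out = extract_kaldi_fbank_feature(dummy_waveform, 16000, dummy_log_mel_spec)
    print(out["ta_kaldi_fbank"].shape)            # torch.Size([1024, 128])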
def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1):
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    if transcription:
        transcription = text2phoneme(transcription)
    transcription = [transcription] * batchsize

    if fbank is None:
        fbank = torch.zeros(
            (batchsize, 1024, 64)
        )  # Not used, here to keep the code format
    else:
        fbank = torch.FloatTensor(fbank)
        fbank = fbank.expand(batchsize, 1024, 64)
        assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Not used

    phonemes = get_vits_phoneme_ids_no_padding(transcription)

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Not used
        ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128))
    else:
        waveform = torch.FloatTensor(waveform)
        waveform = waveform.expand(batchsize, -1)
        assert waveform.size(0) == batchsize
        # Index the returned dict so both branches yield a tensor, not a dict.
        ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank)["ta_kaldi_fbank"]
    batch = {
        "text": text,  # list
        "fname": [text_to_filename(t) for t in text],  # list
        "waveform": waveform,
        "stft": stft,
        "log_mel_spec": fbank,
        "ta_kaldi_fbank": ta_kaldi_fbank,
    }
    batch.update(phonemes)
    return batch
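# Illustrative sketch, not called anywhere in this module: the text-only batch
# construction that text_to_audio() below relies on (requires the audioldm2 package).
def _example_make_batch():
    batch = make_batch_for_text_to_audio("A cat meowing", transcription="", batchsize=2)
    print(batch["log_mel_spec"].shape)    # torch.Size([2, 1024, 64])
    print(batch["ta_kaldi_fbank"].shape)  # torch.Size([2, 1024, 128])
    print(batch["fname"])                 # ['A_cat_meowing', 'A_cat_meowing']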
def round_up_duration(duration):
    # Round the requested duration up to a multiple of 2.5 s; the "+ 1" makes the
    # result conservative, so it can overshoot the exact duration by one 2.5 s step.
    return int(round(duration / 2.5) + 1) * 2.5
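# Illustrative values for the helper above:
#   round_up_duration(2.4) -> 5.0, round_up_duration(2.6) -> 5.0, round_up_duration(5.0) -> 7.5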
# def split_clap_weight_to_pth(checkpoint):
#     if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")):
#         return
#     print("Constructing the weight for the CLAP model.")
#     include_keys = "cond_stage_models.0.cond_stage_models.0.model."
#     new_state_dict = {}
#     for each in checkpoint["state_dict"].keys():
#         if include_keys in each:
#             new_state_dict[each.replace(include_keys, "module.")] = checkpoint[
#                 "state_dict"
#             ][each]
#     torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth"))
def build_model(ckpt_path=None, config=None, device=None, model_name="audioldm2-full"):
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioLDM-2: %s" % model_name)
    print("Loading model on %s" % device)
    # Only download the default checkpoint when no local path was supplied.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)
    if config is not None:
        assert type(config) is str
        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    # config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    resume_from_checkpoint = ckpt_path
    checkpoint = torch.load(resume_from_checkpoint, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion
def text_to_audio(
    latent_diffusion,
    text,
    transcription="",
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=3.5,
    n_candidate_gen_per_text=3,
    latent_t_per_second=25.6,
    config=None,
):
    seed_everything(int(seed))
    waveform = None

    batch = make_batch_for_text_to_audio(
        text, transcription=transcription, waveform=waveform, batchsize=batchsize
    )

    # The latent time dimension scales linearly with the requested duration.
    latent_diffusion.latent_t_size = int(duration * latent_t_per_second)

    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_gen=n_candidate_gen_per_text,
            duration=duration,
        )

    return waveform
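# Minimal end-to-end sketch (an assumed usage pattern, not part of the upstream
# module): build the model, synthesize one prompt, and save the first candidate.
# Assumptions: generate_batch returns a numpy array of shape [n_candidates, 1, samples]
# sampled at 16 kHz; the checkpoint download can be several GB on first run.
if __name__ == "__main__":
    model = build_model(model_name="audioldm2-full")
    wav = text_to_audio(
        model,
        text="A hammer hitting a wooden surface",
        duration=10,
        ddim_steps=200,
        guidance_scale=3.5,
        n_candidate_gen_per_text=3,
    )
    # Save the first generated candidate; adjust the indexing/sample rate if the
    # output format differs from the assumption above.
    torchaudio.save("output.wav", torch.from_numpy(wav[0]), 16000)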