import os
import argparse
import time

import torch
import yaml

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config
def make_batch_for_text_to_audio(text, batchsize=1):
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    stft = torch.zeros((batchsize, 1024, 512))  # Not used
    waveform = torch.zeros((batchsize, 160000))  # Not used
    fname = [""] * batchsize  # Not used
    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
def build_model(ckpt_path=None, config=None, model_name="audioldm-s-full"):
    print("Load AudioLDM: %s" % model_name)
    resume_from_checkpoint = "ckpt/%s.ckpt" % model_name
    # if ckpt_path is None:
    #     ckpt_path = get_metadata()[model_name]["path"]
    # if not os.path.exists(ckpt_path):
    #     download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert type(config) is str
        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(resume_from_checkpoint, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)
    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion
def duration_to_latent_t_size(duration):
    # 25.6 latent frames per second: a 10 s clip yields a latent time size
    # of 256 (the 1024-frame mel spectrogram downsampled 4x).
    return int(duration * 25.6)
def text_to_audio(
    latent_diffusion,
    text,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    seed_everything(int(seed))
    batch = make_batch_for_text_to_audio(text, batchsize=batchsize)

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
    with torch.no_grad():
        waveform = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )
    return waveform
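

# A minimal usage sketch (not part of the original module): it assumes a
# checkpoint exists at ckpt/audioldm-s-full.ckpt and that soundfile is
# installed for writing the result. The output shape and the 16 kHz sample
# rate match the 160000-sample/10 s waveforms used above, but treat both
# as assumptions and adjust for your setup.
if __name__ == "__main__":
    import soundfile as sf

    model = build_model(model_name="audioldm-s-full")
    waveform = text_to_audio(
        model,
        text="A hammer is hitting a wooden surface",
        seed=42,
        duration=10,
        guidance_scale=2.5,
        n_candidate_gen_per_text=3,
    )
    # Assumed output layout: (batch, channel, samples) at 16 kHz.
    sf.write("output.wav", waveform[0, 0], samplerate=16000)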