import os
import time
import json
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import gradio as gr
from tqdm import tqdm
from scipy.io.wavfile import write
from transformers import T5EncoderModel, AutoTokenizer
from diffusers import PNDMScheduler

from models.tta.autoencoder.autoencoder import AutoencoderKL
from models.tta.ldm.inference_utils.vocoder import Generator
from models.tta.ldm.audioldm import AudioLDM
from utils.util import load_config
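
# Simple dict that also exposes its keys as attributes; the HiFi-GAN config loader expects this.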
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
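
# Load the pretrained VAE that maps mel spectrograms to and from the diffusion latent space.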
def build_autoencoderkl(cfg, device):
    autoencoderkl = AutoencoderKL(cfg.model.autoencoderkl)
    autoencoder_path = cfg.model.autoencoder_path
    checkpoint = torch.load(autoencoder_path, map_location="cpu")
    autoencoderkl.load_state_dict(checkpoint["model"])
    autoencoderkl = autoencoderkl.to(device=device)
    autoencoderkl.requires_grad_(requires_grad=False)
    autoencoderkl.eval()
    return autoencoderkl
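
# Load the T5 text encoder and tokenizer, falling back to local checkpoints if the Hub download fails.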
def build_textencoder(device):
    try:
        tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
        text_encoder = T5EncoderModel.from_pretrained("t5-base")
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
        text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
    text_encoder = text_encoder.to(device=device)
    text_encoder.requires_grad_(requires_grad=False)
    text_encoder.eval()
    return tokenizer, text_encoder
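
# Load the HiFi-GAN vocoder that turns mel spectrograms into waveforms.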
def build_vocoder(device):
    config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
    with open(config_file) as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)
    vocoder = Generator(h).to(device)
    checkpoint_dict = torch.load(
        "ckpts/tta/hifigan_checkpoints/g_01250000", map_location=device
    )
    vocoder.load_state_dict(checkpoint_dict["generator"])
    return vocoder


def build_model(cfg):
    model = AudioLDM(cfg.model.audioldm)
    return model
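
# Encode the prompt and an empty (unconditional) prompt, concatenated for classifier-free guidance.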
def get_text_embedding(text, tokenizer, text_encoder, device):
    prompt = [text]
    text_input = tokenizer(
        prompt,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    return text_embeddings
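
# Full text-to-audio pipeline: text -> latent diffusion -> mel spectrogram -> waveform, saved as a wav file.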
def tta_inference(
    text,
    guidance_scale=4,
    diffusion_steps=100,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tta/audioldm/exp_config.json")
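
    # Build the frozen VAE, text encoder, vocoder, and the latent diffusion model.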
    autoencoderkl = build_autoencoderkl(cfg, device)
    tokenizer, text_encoder = build_textencoder(device)
    vocoder = build_vocoder(device)
    model = build_model(cfg)

    # Load the trained AudioLDM checkpoint.
    checkpoint_path = "ckpts/tta/audioldm_debug_latent_size_4_5_39/checkpoints/step-0570000_loss-0.2521.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    text_embeddings = get_text_embedding(text, tokenizer, text_encoder, device)
    num_steps = diffusion_steps
    noise_scheduler = PNDMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
        set_alpha_to_one=False,
        steps_offset=1,
        prediction_type="epsilon",
    )
    noise_scheduler.set_timesteps(num_steps)
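
    # Start from Gaussian noise in the latent space; the spatial size is the
    # 80 x 624 mel spectrogram downsampled by the autoencoder.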
    latents = torch.randn(
        (
            1,
            cfg.model.autoencoderkl.z_channels,
            80 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
            624 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
        )
    ).to(device)
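
    # Iterative denoising: each step predicts the noise for the unconditional and
    # text-conditioned latents, mixes the two with the guidance scale, and advances the scheduler.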
    model.eval()
    for t in tqdm(noise_scheduler.timesteps):
        t = t.to(device)

        # Expand the latents so classifier-free guidance needs only one forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(
            latent_model_input, timestep=t
        )

        # Predict the noise residual.
        with torch.no_grad():
            noise_pred = model(
                latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
            )

        # Perform guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents_out = latents
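
    # Decode the denoised latents back into a mel spectrogram with the VAE decoder.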
    with torch.no_grad():
        mel_out = autoencoderkl.decode(latents_out)
    melspec = mel_out[0, 0].cpu().detach().numpy()
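
    # Vocode the mel spectrogram into a 16 kHz waveform and scale it to int16 range.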
    vocoder.eval()
    vocoder.remove_weight_norm()
    with torch.no_grad():
        melspec = np.expand_dims(melspec, 0)
        melspec = torch.FloatTensor(melspec).to(device)
        y = vocoder(melspec)
        audio = y.squeeze()
        audio = audio * 32768.0
        audio = audio.cpu().numpy().astype("int16")

    os.makedirs("result", exist_ok=True)
    write(os.path.join("result", text + ".wav"), 16000, audio)
    return os.path.join("result", text + ".wav")
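
# Gradio interface: a text prompt plus guidance-scale and step-count sliders, returning the generated audio clip.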
demo_inputs = [
    gr.Textbox(
        value="birds singing and a man whistling",
        label="Text prompt you want to generate",
        type="text",
    ),
    gr.Slider(
        1,
        10,
        value=4,
        step=1,
        label="Classifier-free guidance scale",
    ),
    gr.Slider(
        50,
        1000,
        value=100,
        step=1,
        label="Diffusion Inference Steps",
        info="More steps give better synthesis quality but slower inference",
    ),
]
| demo_outputs = gr.Audio(label="") | |
| demo = gr.Interface( | |
| fn=tta_inference, | |
| inputs=demo_inputs, | |
| outputs=demo_outputs, | |
| title="Amphion Text to Audio", | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |