import gradio as gr
import json
import torch
import wavio
import numpy as np
from tqdm import tqdm
from huggingface_hub import snapshot_download
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from transformers import AutoTokenizer, T5ForConditionalGeneration
from modelling_deberta_v2 import DebertaV2ForTokenClassificationRegression

import sys

# Prefer the bundled diffusers checkout over any installed version.
sys.path.insert(0, "diffusers/src")
from diffusers import DDPMScheduler
from models import MusicAudioDiffusion

from gradio import Markdown
import spaces

# Automatic device detection
if torch.cuda.is_available():
    device_type = "cuda"
    device_selection = "cuda:0"
else:
    device_type = "cpu"
    device_selection = "cpu"


class MusicFeaturePredictor:
    """Predicts the beat and chord control features that condition Mustango.

    Beats come from a DeBERTa-v3 token-classification/regression head; chords
    are generated by a FLAN-T5 model conditioned on the caption and the
    predicted beat timestamps.
    """

    def __init__(
        self, path, device=device_selection, cache_dir=None, local_files_only=False
    ):
        self.beats_tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/deberta-v3-large",
            use_fast=False,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.beats_model = DebertaV2ForTokenClassificationRegression.from_pretrained(
            "microsoft/deberta-v3-large",
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.beats_model.eval()
        self.beats_model.to(device)
        beats_ckpt = f"{path}/beats/microsoft-deberta-v3-large.pt"
        beats_weight = torch.load(beats_ckpt, map_location="cpu")
        self.beats_model.load_state_dict(beats_weight)

        self.chords_tokenizer = AutoTokenizer.from_pretrained(
            "google/flan-t5-large",
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.chords_model = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-large",
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.chords_model.eval()
        self.chords_model.to(device)
        chords_ckpt = f"{path}/chords/flan-t5-large.bin"
        chords_weight = torch.load(chords_ckpt, map_location="cpu")
        self.chords_model.load_state_dict(chords_weight)

    def generate_beats(self, prompt):
        tokenized = self.beats_tokenizer(
            prompt, max_length=512, padding=True, truncation=True, return_tensors="pt"
        )
        tokenized = {k: v.to(self.beats_model.device) for k, v in tokenized.items()}

        with torch.no_grad():
            out = self.beats_model(**tokenized)

        # Classification logits give the max beat count; the regression head
        # gives inter-beat intervals.
        max_beat = (
            1 + torch.argmax(out["logits"][:, 0, :], -1).detach().cpu().numpy()
        ).tolist()[0]
        intervals = (
            out["values"][:, :, 0]
            .detach()
            .cpu()
            .numpy()
            .astype("float32")
            .round(4)
            .tolist()
        )

        # Convert intervals to absolute beat times, keeping only beats that
        # fall within the first 10 seconds (at most 50 of them).
        intervals = np.cumsum(intervals)
        predicted_beats_times = []
        for t in intervals:
            if t < 10:
                predicted_beats_times.append(round(t, 2))
            else:
                break
        predicted_beats_times = list(np.array(predicted_beats_times)[:50])

        if len(predicted_beats_times) == 0:
            predicted_beats = [[], []]
        else:
            # Cycle the beat count 1..max_beat over the predicted beat times.
            beat_counts = []
            for i in range(len(predicted_beats_times)):
                beat_counts.append(float(1.0 + np.mod(i, max_beat)))
            predicted_beats = [[predicted_beats_times, beat_counts]]

        return max_beat, predicted_beats_times, predicted_beats

    def generate(self, prompt):
        max_beat, predicted_beats_times, predicted_beats = self.generate_beats(prompt)

        # The chord model expects the caption, beat timestamps, and max beat
        # joined by literal "\n" markers.
        chords_prompt = "Caption: {} \\n Timestamps: {} \\n Max Beat: {}".format(
            prompt,
            " , ".join([str(round(t, 2)) for t in predicted_beats_times]),
            max_beat,
        )

        tokenized = self.chords_tokenizer(
            chords_prompt,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        tokenized = {k: v.to(self.chords_model.device) for k, v in tokenized.items()}

        generated_chords = self.chords_model.generate(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
            min_length=8,
            max_length=128,
            num_beams=5,
            early_stopping=True,
            num_return_sequences=1,
        )
        # The decoded output lists "<chord> at <time>" entries separated
        # by " n ".
        generated_chords = self.chords_tokenizer.decode(
            generated_chords[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).split(" n ")

        predicted_chords, predicted_chords_times = [], []
        for item in generated_chords:
            c, ct = item.split(" at ")
            predicted_chords.append(c)
            predicted_chords_times.append(float(ct))

        return predicted_beats, predicted_chords, predicted_chords_times


class Mustango:
    def __init__(
        self,
        name="declare-lab/mustango",
        device=device_selection,
        cache_dir=None,
        local_files_only=False,
    ):
        path = snapshot_download(repo_id=name, cache_dir=cache_dir)

        self.music_model = MusicFeaturePredictor(
            path, device, cache_dir=cache_dir, local_files_only=local_files_only
        )

        vae_config = json.load(open(f"{path}/configs/vae_config.json"))
        stft_config = json.load(open(f"{path}/configs/stft_config.json"))
        main_config = json.load(open(f"{path}/configs/main_config.json"))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = MusicAudioDiffusion(
            main_config["text_encoder_name"],
            main_config["scheduler_name"],
            unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
        ).to(device)
        # self.model.device = device

        vae_weights = torch.load(
            f"{path}/vae/pytorch_model_vae.bin", map_location=device
        )
        stft_weights = torch.load(
            f"{path}/stft/pytorch_model_stft.bin", map_location=device
        )
        main_weights = torch.load(
            f"{path}/ldm/pytorch_model_ldm.bin", map_location=device
        )

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)
        print("Successfully loaded checkpoint from:", name)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate music for a single prompt string."""
        with torch.no_grad():
            # Predict the beat/chord control features, run latent diffusion,
            # then decode the latent to a mel spectrogram and a waveform.
            beats, chords, chords_times = self.music_model.generate(prompt)
            latents = self.model.inference(
                [prompt],
                beats,
                [chords],
                [chords_times],
                self.scheduler,
                steps,
                guidance,
                samples,
                disable_progress,
            )
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)

        return wave[0]


# Initialize Mustango on CPU, then move the submodules to the detected device.
mustango = Mustango(device="cpu")
mustango.vae.to(device_type)
mustango.stft.to(device_type)
mustango.model.to(device_type)
mustango.music_model.beats_model.to(device_type)
mustango.music_model.chords_model.to(device_type)

# if torch.cuda.is_available():
#     mustango = Mustango()
# else:
#     mustango = Mustango(device="cpu")

# output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)


@spaces.GPU(duration=120)
def gradio_generate(prompt, steps, guidance):
    output_wave = mustango.generate(prompt, steps, guidance)
    # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
    output_filename = "temp.wav"
    wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
    return output_filename


title = "Mustango: Toward Controllable Text-to-Music Generation"
description_text = """
For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings.

Generate music using Mustango by providing a text prompt.

This is the demo for Mustango, a model for controllable text-to-music generation: Read our paper.
"""
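
# A minimal sketch of wiring gradio_generate, title, and description_text
# into a Gradio interface so the script launches as a Space. The widget
# choices, slider ranges, and queue settings below are assumptions, not
# part of the snippet above.
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Music", type="filepath")
denoising_steps = gr.Slider(
    minimum=10, maximum=200, value=100, step=1, label="Steps", interactive=True
)
guidance_scale = gr.Slider(
    minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True
)

gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=[input_text, denoising_steps, guidance_scale],
    outputs=[output_audio],
    title=title,
    description=description_text,
)

# Queue requests so concurrent users don't contend for the GPU worker.
gr_interface.queue().launch()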