# Imports
import gradio as gr
import os
import matplotlib.pyplot as plt
import torch
import torchaudio
from torch import nn
import pytorch_lightning as pl
from ema_pytorch import EMA
import yaml
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler


# Load configs
def load_configs(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    pl_configs = config['model']
    model_configs = config['model']['model']
    return pl_configs, model_configs


# Plot mel spectrogram
def plot_mel_spectrogram(sample, sr):
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )
    spectrogram = transform(torch.mean(sample, dim=0))  # downmix to mono and compute the spectrogram
    spectrogram = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0)

    # Plot the mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(spectrogram, aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.xlabel('Frame')
    plt.ylabel('Mel Bin')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    return fig


# Define PyTorch Lightning model
class Model(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        self.lr = lr
        self.lr_beta1 = lr_beta1
        self.lr_beta2 = lr_beta2
        self.lr_eps = lr_eps
        self.lr_weight_decay = lr_weight_decay
        self.model = model
        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)


# Instantiate model (must match the model that was trained)
def load_model(model_configs, pl_configs) -> nn.Module:
    # Diffusion model
    model = DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        in_channels=model_configs['in_channels'],  # U-Net: number of input/output (audio) channels
        channels=model_configs['channels'],  # U-Net: channels at each layer
        factors=model_configs['factors'],  # U-Net: downsampling and upsampling factors at each layer
        items=model_configs['items'],  # U-Net: number of repeating items at each layer
        attentions=model_configs['attentions'],  # U-Net: attention enabled/disabled at each layer
        attention_heads=model_configs['attention_heads'],  # U-Net: number of attention heads per attention item
        attention_features=model_configs['attention_features'],  # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion,  # The diffusion method used
        sampler_t=VSampler,  # The diffusion sampler used
    )

    # PyTorch Lightning model
    model = Model(
        lr=pl_configs['lr'],
        lr_beta1=pl_configs['lr_beta1'],
        lr_beta2=pl_configs['lr_beta2'],
        lr_eps=pl_configs['lr_eps'],
        lr_weight_decay=pl_configs['lr_weight_decay'],
        ema_beta=pl_configs['ema_beta'],
        ema_power=pl_configs['ema_power'],
        model=model,
    )
    return model


# Assign to GPU if available
def assign_to_gpu(model):
    if torch.cuda.is_available():
        model = model.to('cuda')
    print(f"Device: {model.device}")
    return model


# Load model checkpoint
def load_checkpoint(model, ckpt_path) -> None:
    checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict']
    model.load_state_dict(checkpoint)  # should report that all keys matched successfully


# Generate samples
def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
    # Load the checkpoint for the selected model
    ckpt_path = models[model_name]
    load_checkpoint(model, ckpt_path)

    if num_samples > 1:
        duration = int(duration / 2)

    # Generate samples
    with torch.no_grad():
        if init_audio:
            # Load conditioning audio sample (Gradio passes a (sample rate, ndarray) tuple)
            audio_sample = torch.tensor(init_audio[1].T, dtype=torch.float32).unsqueeze(0).to(model.device)
            audio_sample = audio_sample / torch.max(torch.abs(audio_sample))  # normalize init_audio

            # Trim or pad audio to the target duration
            og_shape = audio_sample.shape
            if duration < og_shape[2]:
                audio_sample = audio_sample[:, :, :duration]
            elif duration > og_shape[2]:
                # Pad tensor with zeros to match sample length
                audio_sample = torch.concat(
                    (audio_sample, torch.zeros(og_shape[0], og_shape[1], duration - og_shape[2]).to(model.device)),
                    dim=2,
                )
        else:
            audio_sample = torch.zeros((1, 2, int(duration)), device=model.device)
            noise_level = 1.0

        all_samples = torch.zeros(2, 0)
        for i in range(num_samples):
            noise = torch.randn_like(audio_sample, device=model.device) * noise_level  # [batch_size, in_channels, length]
            audio = (audio_sample * abs(1 - noise_level)) + noise  # add noise

            # Generate a sample (suggested num_steps: 10-100)
            generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu()

            # Concatenate all samples along the time axis
            all_samples = torch.concat((all_samples, generated_sample), dim=1)

    torch.cuda.empty_cache()

    fig = plot_mel_spectrogram(all_samples, sr)
    plt.title(f"{model_name} Mel Spectrogram")

    return (sr, all_samples.cpu().detach().numpy().T), fig  # (sample rate, audio), plot


# Define constants & initialize model
sr = 44100  # sampling rate
config_path = "saved_models/config.yaml"  # config path

# Load model & configs
pl_configs, model_configs = load_configs(config_path)
model = load_model(model_configs, pl_configs)
model = assign_to_gpu(model)

models = {
    "Kicks": "saved_models/kicks/kicks_v7.ckpt",
    "Snares": "saved_models/snares/snares_v0.ckpt",
    "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
    "Percussion": "saved_models/percussion/percussion_v0.ckpt"
}

intro = """

<h1>Tiny Audio Diffusion</h1>

<p>Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (<2GB VRAM)</p>

<p>GitHub Repo | Repo Tutorial Video | Towards Data Science Article</p>

""" with gr.Blocks() as demo: # Layout gr.HTML(intro) with gr.Row(equal_height=False): with gr.Column(): # Inputs model_name = gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model") num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3) num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15) # Conditioning Audio Input with gr.Accordion("Input Audio (optional)", open=False): init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.
Leaving input audio blank results in unconditional generation.') init_audio = gr.Audio(label="Input Audio Sample") init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=0.70)#, visible=True) # Examples gr.Examples( examples=[ os.path.join(os.path.dirname(__file__), "samples", "guitar.wav"), os.path.join(os.path.dirname(__file__), "samples", "snare.wav"), os.path.join(os.path.dirname(__file__), "samples", "kick.wav"), os.path.join(os.path.dirname(__file__), "samples", "hihat.wav") ], inputs=init_audio, label="Example Audio Inputs" ) # Buttons with gr.Row(): with gr.Column(): clear_button = gr.Button(value="Reset All") with gr.Column(): generate_btn = gr.Button("Generate Samples!") with gr.Column(): # Outputs output_audio = gr.Audio(label="Generated Audio Sample") output_plot = gr.Plot(label="Generated Audio Spectrogram") # Functionality # Generate samples generate_btn.click(fn=generate_samples, inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise], outputs=[output_audio, output_plot]) # clear_button button to reset everything clear_button.click(fn=lambda: [3, 15, None, 0.70, None, None], outputs=[num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot]) if __name__ == "__main__": demo.launch()
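
# --- Optional: programmatic usage sketch (not part of the Gradio app) ---
# A minimal example of calling generate_samples() directly, e.g. from a notebook or
# another script. Illustrative only: it assumes the checkpoints listed in `models`
# exist locally, and the output file names below are arbitrary.
#
#   (rate, audio), fig = generate_samples("Kicks", num_samples=1, num_steps=50)
#   torchaudio.save("generated_kick.wav", torch.tensor(audio.T), rate)  # audio is (length, channels)
#   fig.savefig("generated_kick_spectrogram.png")
#
# Expected layout of saved_models/config.yaml, inferred from the keys read in
# load_configs() and load_model() (values omitted; use those from training):
#
#   model:
#     lr: ...
#     lr_beta1: ...
#     lr_beta2: ...
#     lr_eps: ...
#     lr_weight_decay: ...
#     ema_beta: ...
#     ema_power: ...
#     model:
#       in_channels: ...
#       channels: ...
#       factors: ...
#       items: ...
#       attentions: ...
#       attention_heads: ...
#       attention_features: ...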