import os
import shutil

import gradio as gr
import torch
from diffusers import UNet2DConditionModel, DDPMScheduler

from config import config
from generator_module import Generator
from mel_module import Mel
from slicer_module import get_slices
from vae_module import VAE
# Remove slices left over from previous runs so they don't leak into the new embedding.
slices_folder = 'slices'
if os.path.exists(slices_folder):
    shutil.rmtree(slices_folder)
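# Load the pretrained VAE that maps mel spectrograms to and from the latent space.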
vae = VAE()
vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu')))
vae.to(config.device)
vae.eval()
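# Pull the conditional UNet denoiser and DDPM noise scheduler from the Hub checkpoint.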
model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler")
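# Gradio callback: turns the uploaded files into a single generated track.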
def generate_new_track(audio_paths, progress=gr.Progress(track_tqdm=True)):
    # Slice each uploaded file into short segments under 'slices/'.
    for audio_path in audio_paths:
        get_slices(audio_path)
    embedding = get_embedding()
    if embedding is None:
        raise gr.Error("No .wav slices were produced; please upload valid .wav files.")
    generator = Generator(config, model, noise_scheduler, vae, embedding, progress_callback=progress)
    generator.generate()
    return config.generated_track_path
def get_embedding():
    """Return the mean of the normalized latent representations of all slices."""
    latents = []
    slices_dir = 'slices'
    for slice_file in os.listdir(slices_dir):
        if not slice_file.endswith('.wav'):  # skip anything that isn't audio
            continue
        mel = Mel(os.path.join(slices_dir, slice_file))
        spectrogram = mel.get_spectrogram()
        # Shape the spectrogram as a (batch, channel, height, width) tensor for the VAE.
        tensor = torch.tensor(spectrogram).float().unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            mu, log_var = vae.encode(tensor)
        latent = torch.cat((mu, log_var), dim=1)
        # Min-max normalize each latent to [-1, 1] before averaging.
        min_val, max_val = latent.min(), latent.max()
        latent = 2 * ((latent - min_val) / (max_val - min_val)) - 1
        latents.append(latent.unsqueeze(0))
    if not latents:
        return None
    # Average across slices to get a single conditioning embedding.
    latents_tensor = torch.cat(latents, dim=0)
    return latents_tensor.mean(dim=0, keepdim=True)
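# Wire everything into a Gradio app: multiple audio files in, one generated track out.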
interface = gr.Interface(
    fn=generate_new_track,
    inputs=gr.Files(file_count="multiple", label="Upload Your Audio Files"),
    outputs=gr.Audio(type="filepath", label="Generated Track"),
    title="AMUSE: Music Generation",
    description=(
        "<h3>Welcome to the AMUSE music generation app</h3>"
        "<p>Here's how it works:</p>"
        "<ol>"
        "<li><strong>Upload Your Audio Files:</strong> Provide audio files whose musical taste will be extracted; "
        "a new track is then generated to match it. The files must be in .wav format.</li>"
        "<li><strong>Process:</strong> The app slices the audio, extracts features, and generates a new track using a VAE and a diffusion model.</li>"
        "<li><strong>Progress:</strong> The progress bar shows the generation in real time. Generation takes a significant amount of time, "
        "so on the free tier you can leave the site and come back later to see the result.</li>"
        "<li><strong>Download:</strong> Once the track is generated, you can download it directly.</li>"
        "</ol>"
        "<h4>Notes:</h4>"
        "<ul>"
        "<li>As mentioned above, generating a new track takes a while on the free tier of HF Spaces, "
        "so submit your tracks, forget about them for a little while :) and then come back for the new track.</li>"
        "<li>For best results, use clean, good-quality audio (.wav format, 44100 Hz sample rate).</li>"
        "</ul>"
    )
)

interface.launch()