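"""AMUSE music generation app.

Gradio front end that slices the uploaded audio, encodes the slices with a
pretrained VAE, and uses the resulting embedding to condition a diffusion
model (UNet + DDPM scheduler) that generates a new track.
"""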
import os
import shutil

import torch
import gradio as gr
from diffusers import UNet2DConditionModel, DDPMScheduler

from vae_module import VAE
from config import config
from slicer_module import get_slices
from mel_module import Mel
from generator_module import Generator
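
# Working folder for the audio slices produced by get_slices(); clear leftovers from any previous run.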
slices_folder = 'slices'
if os.path.exists(slices_folder):
    shutil.rmtree(slices_folder)
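
# Load the pretrained VAE weights, move the model to the configured device, and switch to eval mode for inference.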
vae = VAE()
vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu')))
vae.to(config.device)
vae.eval()
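
# Fetch the trained UNet and the DDPM noise scheduler from the Hub repo configured in config.hub_model_id.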
model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler")

def generate_new_track(audio_paths, progress=gr.Progress(track_tqdm=True)):
    """Slice the uploaded files, build a taste embedding, and generate a new track."""
    # Start from an empty slices folder so repeated runs don't mix in old uploads.
    if os.path.exists(slices_folder):
        shutil.rmtree(slices_folder)

    for audio_path in audio_paths:
        get_slices(audio_path)

    embedding = get_embedding()
    if embedding is None:
        raise gr.Error("No .wav slices could be extracted from the uploaded files.")
    print("sample latent", embedding.shape)

    generator = Generator(config, model, noise_scheduler, vae, embedding, progress_callback=progress)
    generator.generate()

    return config.generated_track_path


def get_embedding():
    """Encode every slice with the VAE and average the latents into a single embedding."""
    latents = []

    for slice_file in os.listdir(slices_folder):
        if slice_file.endswith('.wav'):
            mel = Mel(os.path.join(slices_folder, slice_file))
            spectrogram = mel.get_spectrogram()
            # Shape the spectrogram as a (batch, channel, H, W) tensor on the same device as the VAE.
            tensor = torch.tensor(spectrogram).float().unsqueeze(0).unsqueeze(0).to(config.device)
            with torch.no_grad():
                mu, log_var = vae.encode(tensor)
            latent = torch.cat((mu, log_var), dim=1)
            # Rescale each latent to [-1, 1] before averaging.
            min_val = latent.min()
            max_val = latent.max()
            normalized_tensor = 2 * ((latent - min_val) / (max_val - min_val)) - 1
            latents.append(normalized_tensor.unsqueeze(0))

    if not latents:
        return None

    latents_tensor = torch.cat(latents, dim=0)
    mean_latent = latents_tensor.mean(dim=0, keepdim=True)
    return mean_latent


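# Gradio UI: upload one or more .wav files, get the generated track back as audio.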
interface = gr.Interface(
    fn=generate_new_track,
    inputs=gr.Files(file_count="multiple", file_types=[".wav"], label="Upload Your Audio Files"),
    outputs=gr.Audio(type="filepath", label="Generated Track"),
title="AMUSE: Music Generation", |
|
description = ( |
|
"<h3>Welcome to the AMUSE music generation app</h3>" |
|
"<p>Here's how it works:</p>" |
|
"<ol>" |
|
"<li><strong>Upload Your Audio Files:</strong> Provide audio files from which the taste will be extracted, " |
|
"and a new track will be generated accordingly. The audio files should be in .wav format!</li>" |
|
"<li><strong>Process:</strong> The app slices the audio, extracts features, and generates a new track using a VAE and a diffusion model.</li>" |
|
"<li><strong>Progress:</strong> The progress bar will show the generation process in real-time. Note that this takes a significant amount of time, " |
|
"so you may leave the site in the free version and come back later to see the result.</li>" |
|
"<li><strong>Download:</strong> Once the track is generated, you can download it directly.</li>" |
|
"</ol>" |
|
"<h4>Notes:</h4>" |
|
"<ul>" |
|
"<li>As mentioned earlier, it takes a significant amount of time to generate a new track in the free version of HF Spaces. " |
|
"So, submit your tracks and forget about it for a little while :) Then come back to see the new track.</li>" |
|
"<li>Ensure your audio files are clean and of good quality for the best results (sample rate: 44100 and .wav format).</li>" |
|
"</ul>" |
|
) |
|
) |
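
# Launch the Gradio app; on a Hugging Face Space this runs when the Space starts.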
interface.launch()