import torchaudio
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration

model_options = {
    "Small Model": "facebook/musicgen-small",
    "Medium Model": "facebook/musicgen-medium",
    "Large Model": "facebook/musicgen-large",
}

style_tags_options = [
    "East Coast", "Trap", "Boom Bap", "Lo-Fi", "Experimental",
    "Rock", "Electronic", "Pop", "Country", "Heavy Metal",
    "Classical", "Jazz", "Reggae",
]
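

# Loading a MusicGen checkpoint is slow and memory-hungry, so cache the
# processor/model pair per model choice instead of reloading on every call.
# This helper (load_model and the _model_cache dict) is an addition to the
# original script, sketched as a minimal in-process cache; models load on
# CPU by default and can be moved with .to("cuda") if a GPU is available.
_model_cache = {}


def load_model(model_choice):
    if model_choice not in _model_cache:
        checkpoint = model_options[model_choice]
        processor = AutoProcessor.from_pretrained(checkpoint)
        model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
        _model_cache[model_choice] = (processor, model)
    return _model_cache[model_choice]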


def generate_spectrogram(audio_tensor, sample_rate):
    # The input is already a time-domain waveform, so plot it directly.
    # (Griffin-Lim reconstructs a waveform *from* a magnitude spectrogram;
    # the original code applied the transform in the wrong direction.)
    waveform = audio_tensor.squeeze().cpu().numpy()

    plt.figure(figsize=(10, 4))
    plt.specgram(waveform, Fs=sample_rate, cmap="viridis")
    plt.colorbar(format="%+2.0f dB")
    plt.title("Spectrogram")
    plt.ylabel("Frequency (Hz)")
    plt.xlabel("Time (s)")
    plt.tight_layout()  # called after the labels so they are accounted for
    spectrogram_path = "generated_spectrogram.png"
    plt.savefig(spectrogram_path)
    plt.close()
    return spectrogram_path


def generate_music(description, model_choice, style_tags, tempo, intensity, duration):
    try:
        processor, model = load_model(model_choice)

        # MusicGen has no dedicated tempo or intensity conditioning, so fold
        # the slider values into the text prompt as plain words; previously
        # these parameters were accepted but never used.
        style_tags_str = " ".join(style_tags)
        prompt = f"{description} {style_tags_str}, {tempo} bpm, intensity {intensity}/10"

        inputs = processor(text=[prompt], return_tensors="pt", padding=True)

        # MusicGen emits roughly 50 audio tokens per second of music, and a
        # single pass tops out around 30 seconds, so clamp the request.
        # Sampling with classifier-free guidance follows the Hugging Face
        # MusicGen examples; greedy decoding tends to produce poor audio.
        max_new_tokens = int(min(duration, 30) * 50)
        audio_output = model.generate(
            **inputs, do_sample=True, guidance_scale=3, max_new_tokens=max_new_tokens
        )

        # Take the sampling rate from the model config (32 kHz for the public
        # MusicGen checkpoints) rather than hard-coding 16 kHz.
        sampling_rate = model.config.audio_encoder.sampling_rate
        output_file = "generated_music.wav"
        torchaudio.save(output_file, audio_output[0].cpu(), sampling_rate)
        spectrogram_path = generate_spectrogram(audio_output[0].cpu(), sampling_rate)

        return output_file, spectrogram_path, None
    except Exception as e:
        return None, None, f"An error occurred: {e}"


iface = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Enter a description for the music"),
        gr.Dropdown(label="Select Model", choices=list(model_options.keys()), value="Small Model"),
        gr.CheckboxGroup(label="Style Tags", choices=style_tags_options),
        gr.Slider(label="Tempo (BPM)", minimum=60, maximum=240, step=1, value=120),
        gr.Slider(label="Intensity", minimum=1, maximum=10, step=1, value=5),
        # Capped at 30 s to match what a single MusicGen pass can produce.
        gr.Slider(label="Duration (Seconds)", minimum=5, maximum=30, step=1, value=15),
    ],
    outputs=[
        gr.Audio(label="Generated Music"),
        gr.Image(label="Spectrogram"),
        gr.Textbox(label="Error Message", visible=True),
    ],
    title="MusicGen Pro XL",
    description=(
        "Generate original music across genres with selectable model size, "
        "style tags, tempo, intensity, and duration. Listen to the result, "
        "view its spectrogram, and see an error message if generation fails."
    ),
)


iface.launch()
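
# Usage note: launch() serves the app locally; Gradio's standard share=True
# option would additionally expose a temporary public URL, e.g.
# iface.launch(share=True).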