import gradio as gr
import spaces
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import tempfile
import os
import torch
from gradio_client import Client, handle_file
import random
import time
import io
from pydub import AudioSegment

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MelodyFlow variation mapping - map semantic variation names to text prompts
VARIATION_PROMPTS = {
    'accordion_folk': 'folk accordion melody with traditional folk instruments',
    'banjo_bluegrass': 'bluegrass banjo with country folk instruments',
    'piano_classical': 'classical piano with orchestral arrangement',
    'celtic': 'celtic harp and flute with traditional irish instruments',
    'strings_quartet': 'string quartet with violin, viola, cello arrangement',
    'synth_retro': 'retro 80s synthesizer with vintage electronic sounds',
    'synth_modern': 'modern synthesizer with contemporary electronic production',
    'synth_edm': 'edm synthesizer with dance electronic beats',
    'lofi_chill': 'lo-fi chill with relaxed jazz hip-hop elements',
    'synth_bass': 'heavy bass synthesizer with sub-bass frequencies',
    'rock_band': 'rock band with electric guitar, bass, and drums',
    'cinematic_epic': 'cinematic epic orchestral with dramatic strings and brass',
    'retro_rpg': 'retro rpg chiptune with 8-bit game music elements',
    'chiptune': '8-bit chiptune with retro video game sounds',
    'steel_drums': 'steel drums with caribbean tropical percussion',
    'gamelan_fusion': 'gamelan fusion with indonesian percussion instruments',
    'music_box': 'music box with delicate mechanical melody',
    'trap_808': 'trap beats with heavy 808 drums and hi-hats',
    'lo_fi_drums': 'lo-fi drums with vinyl crackle and jazz samples',
    'boom_bap': 'boom bap hip-hop with classic drum breaks',
    'percussion_ensemble': 'percussion ensemble with varied drum instruments',
    'future_bass': 'future bass with melodic drops and vocal chops',
    'synthwave_retro': 'synthwave retro with neon 80s aesthetic',
    'melodic_techno': 'melodic techno with driving beats and emotional melodies',
    'dubstep_wobble': 'dubstep with heavy wobble bass and electronic drops',
    'glitch_hop': 'glitch hop with broken beats and digital artifacts',
    'digital_disruption': 'digital disruption with glitchy electronic effects',
    'circuit_bent': 'circuit bent with broken electronic hardware sounds',
    'orchestral_glitch': 'orchestral glitch with classical instruments and digital errors',
    'vapor_drums': 'vaporwave drums with slowed down nostalgic beats',
    'industrial_textures': 'industrial textures with harsh mechanical sounds',
    'jungle_breaks': 'jungle breaks with fast drum and bass rhythms'
}

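# Example lookup (illustrates how transform_with_melodyflow_api below resolves a style,
# falling back to a generic prompt for unknown keys):
#   VARIATION_PROMPTS.get('celtic', 'transform this audio to celtic style')
#   -> 'celtic harp and flute with traditional irish instruments'
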
def preprocess_audio(waveform):
    # Move to CPU, drop any singleton dimensions, then re-add a leading
    # batch/channel dimension and return the tensor on the target device.
    waveform_np = waveform.cpu().squeeze().numpy()
    return torch.from_numpy(waveform_np).unsqueeze(0).to(device)

# ========== MUSICGEN FUNCTIONS (Local ZeroGPU) ==========

@spaces.GPU  # assumed ZeroGPU decorator: the Space runs on Zero and imports `spaces`
def generate_drum_sample():
    model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
    model.set_generation_params(duration=10)
    wav = model.generate_unconditional(1).squeeze(0)

    filename_without_extension = 'jungle'
    filename_with_extension = f'{filename_without_extension}.wav'
    # audio_write appends the ".wav" suffix to the stem it is given
    audio_write(filename_without_extension, wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
    return filename_with_extension

@spaces.GPU  # assumed ZeroGPU decorator (see note above)
def continue_drum_sample(existing_audio_path):
    if existing_audio_path is None:
        return None

    existing_audio, sr = torchaudio.load(existing_audio_path)
    existing_audio = existing_audio.to(device)

    prompt_duration = 2
    output_duration = 10
    num_samples = int(prompt_duration * sr)
    if existing_audio.shape[1] < num_samples:
        raise ValueError("The existing audio is too short for the specified prompt duration.")

    # Use the last `prompt_duration` seconds as the continuation prompt
    start_sample = existing_audio.shape[1] - num_samples
    prompt_waveform = existing_audio[..., start_sample:]

    model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
    model.set_generation_params(duration=output_duration)
    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.to(device)

    if output.dim() == 3:
        output = output.squeeze(0)
    if output.dim() == 1:
        output = output.unsqueeze(0)

    combined_audio = torch.cat((existing_audio, output), dim=1)
    combined_audio = combined_audio.cpu()
    combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
    torchaudio.save(combined_file_path, combined_audio, sr)
    return combined_file_path

@spaces.GPU  # assumed ZeroGPU decorator (see note above)
def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
    """Generate music using the BEGINNING of the audio as prompt."""
    if wav_filename is None:
        return None

    song, sr = torchaudio.load(wav_filename)
    song = song.to(device)

    # Dropdown entries look like "repo/name (size)"; the repo id is the first token
    model_name = musicgen_model.split(" ")[0]
    model_continue = MusicGen.get_pretrained(model_name)
    model_continue.set_generation_params(
        use_sampling=True,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        duration=output_duration,
        cfg_coef=3
    )

    # Use the first `prompt_duration` seconds as the prompt
    prompt_waveform = song[..., :int(prompt_duration * sr)]
    prompt_waveform = preprocess_audio(prompt_waveform)

    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.cpu()
    if len(output.size()) > 2:
        output = output.squeeze()

    filename_without_extension = 'continued_music'
    filename_with_extension = f'{filename_without_extension}.wav'
    audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
    return filename_with_extension

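# The two continuation paths differ only in which slice of the source audio becomes the
# MusicGen prompt (sketch, assuming a [channels, samples] tensor `song` at sample rate `sr`):
#   from the beginning (generate_music above):  song[..., :int(prompt_duration * sr)]
#   from the end (continue_music below):        song[..., song.shape[1] - int(prompt_duration * sr):]
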
@spaces.GPU  # assumed ZeroGPU decorator (see note above)
def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
    """Continue music using the END of the audio as prompt - extends the audio."""
    if input_audio_path is None:
        return None

    song, sr = torchaudio.load(input_audio_path)
    song = song.to(device)

    model_name = musicgen_model.split(" ")[0]
    model_continue = MusicGen.get_pretrained(model_name)
    model_continue.set_generation_params(
        use_sampling=True,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        duration=output_duration,
        cfg_coef=3
    )

    # Load the original audio as an AudioSegment for easier manipulation
    original_audio = AudioSegment.from_wav(input_audio_path)
    file_paths_for_cleanup = []

    # Get the last `prompt_duration` seconds as the prompt
    num_samples = int(prompt_duration * sr)
    if song.shape[1] < num_samples:
        raise ValueError("The prompt_duration is longer than the current audio length.")

    # Extract the end portion for prompting
    start_sample = song.shape[1] - num_samples
    prompt_waveform = song[..., start_sample:]
    prompt_waveform = preprocess_audio(prompt_waveform)

    # Generate the continuation
    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.cpu()
    if len(output.size()) > 2:
        output = output.squeeze()

    # Save the generated audio WITHOUT aggressive loudness processing
    filename_without_extension = f'continue_extension_{random.randint(1000, 9999)}'
    filename_with_extension = f'{filename_without_extension}.wav'
    audio_write(filename_without_extension, output, model_continue.sample_rate,
                strategy="clip")  # just prevent clipping, no loudness changes

    # Handle the double .wav extension issue (audio_write appends ".wav" to the stem)
    correct_filename = f'{filename_without_extension}.wav.wav'
    if os.path.exists(correct_filename):
        generated_audio_segment = AudioSegment.from_wav(correct_filename)
        file_paths_for_cleanup.append(correct_filename)
    else:
        generated_audio_segment = AudioSegment.from_wav(filename_with_extension)
        file_paths_for_cleanup.append(filename_with_extension)

    # SMART VOLUME MATCHING: only match against the prompt portion
    # 1. Remove the prompt duration from the original (no overlap with the generated audio)
    prompt_duration_ms = int(prompt_duration * 1000)
    original_minus_prompt = original_audio[:-prompt_duration_ms]

    # 2. Extract JUST the prompt portion from the generated audio for RMS analysis
    generated_prompt_portion = generated_audio_segment[:prompt_duration_ms]

    # 3. Calculate RMS at the transition points
    original_rms = original_minus_prompt.rms
    prompt_portion_rms = generated_prompt_portion.rms

    print("🔍 Smart volume analysis:")
    print(f"  Original ending RMS: {original_rms}")
    print(f"  Generated prompt RMS: {prompt_portion_rms}")
    print(f"  Generated full RMS: {generated_audio_segment.rms}")

    # 4. Match the prompt portion to the original level
    if prompt_portion_rms > 0:
        from pydub.utils import ratio_to_db
        volume_adjustment = ratio_to_db(original_rms / prompt_portion_rms)
        print(f"  Applying {volume_adjustment:.1f} dB to the entire generated segment")
        # Apply to the entire segment (preserves the buildup)
        generated_matched = generated_audio_segment + volume_adjustment
    else:
        generated_matched = generated_audio_segment

    # 5. Combine seamlessly
    combined_audio = original_minus_prompt + generated_matched

    # Save the final result
    combined_audio_filename = f"extended_audio_{random.randint(1000, 9999)}.wav"
    combined_audio.export(combined_audio_filename, format="wav")

    # Clean up temporary files
    for file_path in file_paths_for_cleanup:
        if os.path.exists(file_path):
            os.remove(file_path)

    return combined_audio_filename

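# Worked example of the RMS matching above: pydub's ratio_to_db(x) returns 20 * log10(x), so if
# the original ending has RMS 1200 and the generated prompt portion has RMS 600, the adjustment
# is ratio_to_db(1200 / 600) ≈ +6.0 dB, and the whole generated segment is raised by about 6 dB
# before being appended. (Illustrative numbers only.)
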
# ========== MELODYFLOW FUNCTIONS (Via Facebook Space) ==========

def transform_with_melodyflow_api(audio_path, variation, custom_prompt="", solver="euler", flowstep=0.12):
    """Transform audio using the facebook/MelodyFlow space API."""
    if audio_path is None:
        return None, "❌ No audio file provided"

    try:
        # Initialize the client for the Facebook MelodyFlow space
        client = Client("facebook/MelodyFlow")

        # Set steps based on the solver. Facebook's space automatically reduces steps for editing:
        # EULER divides the requested steps by 5, MIDPOINT divides them by 2.
        if solver == "midpoint":
            base_steps = 128
            effective_steps = base_steps // 2  # 64 effective steps
        else:  # euler
            base_steps = 125
            effective_steps = base_steps // 5  # 25 effective steps

        # Determine the prompt to use (computed after the step counts so the status message is accurate)
        if custom_prompt.strip():
            prompt_text = custom_prompt.strip()
            status_msg = f"✅ Transformed with custom prompt: '{prompt_text}' (flowstep: {flowstep}, {effective_steps} steps)"
        else:
            prompt_text = VARIATION_PROMPTS.get(variation, f"transform this audio to {variation} style")
            status_msg = f"✅ Transformed with {variation} style (flowstep: {flowstep}, {effective_steps} steps)"

        # Call the MelodyFlow API with the base steps (the space will auto-reduce them)
        result = client.predict(
            model="facebook/melodyflow-t24-30secs",
            text=prompt_text,
            solver=solver,
            steps=base_steps,              # will be auto-reduced to effective_steps by the space
            target_flowstep=flowstep,      # this is the key parameter!
            regularize=(solver == "euler"),  # regularize for euler, not for midpoint
            regularization_strength=0.2,
            duration=30,                   # max duration
            melody=handle_file(audio_path),
            api_name="/predict"
        )

        # The result is a tuple of three audio files (variations); use the first one
        if result and len(result) > 0 and result[0]:
            # Save the result locally
            output_filename = f"melodyflow_{variation}_{random.randint(1000, 9999)}.wav"
            import shutil
            shutil.copy2(result[0], output_filename)
            return output_filename, status_msg
        else:
            return None, "❌ MelodyFlow API returned no results"

    except Exception as e:
        return None, f"❌ MelodyFlow API error: {str(e)}"

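# Minimal standalone sketch of the same Space call, for reference outside this app. It assumes
# the facebook/MelodyFlow Space exposes "/predict" with the parameter names used above:
#
#   from gradio_client import Client, handle_file
#   client = Client("facebook/MelodyFlow")
#   result = client.predict(
#       model="facebook/melodyflow-t24-30secs",
#       text="lo-fi chill with relaxed jazz hip-hop elements",
#       solver="euler",
#       steps=125,                      # auto-reduced to 25 effective steps for editing
#       target_flowstep=0.12,
#       regularize=True,
#       regularization_strength=0.2,
#       duration=30,
#       melody=handle_file("input.wav"),
#       api_name="/predict",
#   )
#   print(result[0])  # local path to the first returned variation
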
# ========== GRADIO INTERFACE ==========

# Create the interface
with gr.Blocks() as iface:
    gr.Markdown("# 🎰 The Mega Slot Machine")
    gr.Markdown("**Hybrid Multi-Model Pipeline**: MicroMusicGen → MelodyFlow (via API) → MusicGen Fine-tunes")
    gr.Markdown("*Demonstrating the workflow from our Ableton device in a web interface!*")

    with gr.Accordion("How This Works", open=False):
        gr.Markdown("""
This demo shows how multiple AI models can work together:

1. **Generate** initial audio with MicroMusicGen (super fast jungle drums)
2. **Transform** it using MelodyFlow (via Facebook's space API)
3. **Continue** with MusicGen fine-tunes (trained on specific genres)
4. **Repeat** the cycle to create infinite musical journeys!

MelodyFlow needs a different PyTorch version than the local models, so we call Facebook's MelodyFlow space through its API instead of loading it here.

**Performance Note**: for audio transformation, MelodyFlow automatically uses fewer steps than for generation:
- EULER solver: 25 effective steps (fast, good quality)
- MIDPOINT solver: 64 effective steps (slower, potentially higher quality)
""")

    # ========== STEP 1: GENERATE ==========
    gr.Markdown("## 🎵 Step 1: Generate Initial Audio")

    with gr.Row():
        with gr.Column():
            generate_button = gr.Button("Generate Jungle Drums", variant="primary", size="lg")
            continue_drum_button = gr.Button("Continue Drums", size="sm")

    main_audio = gr.Audio(
        label="🎵 Current Audio (flows through pipeline)",
        type="filepath",
        interactive=True,
        show_download_button=True
    )

    # ========== STEP 2: TRANSFORM ==========
    gr.Markdown("## 🎛️ Step 2: Transform with MelodyFlow")

    with gr.Row():
        with gr.Column(scale=2):
            transform_variation = gr.Dropdown(
                label="Transform Style",
                choices=list(VARIATION_PROMPTS.keys()),
                value="synth_modern",
                interactive=True
            )
        with gr.Column(scale=3):
            transform_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Leave empty to use style above, or enter custom transformation prompt",
                lines=2
            )

    with gr.Row():
        transform_solver = gr.Dropdown(
            label="Solver",
            choices=["euler", "midpoint"],
            value="euler",
            info="EULER: faster (25 steps), MIDPOINT: slower but potentially higher quality (64 steps)"
        )
        transform_flowstep = gr.Slider(
            label="Transform Intensity (Flowstep)",
            minimum=0.0,
            maximum=0.15,
            step=0.01,
            value=0.12,
            info="Lower values = more dramatic transformation"
        )

    transform_button = gr.Button("🎛️ Transform Audio", variant="secondary", size="lg")
    transform_status = gr.Textbox(label="Transform Status", value="Ready to transform", interactive=False)

    # ========== STEP 3: CONTINUE ==========
    gr.Markdown("## 🎼 Step 3: Continue with MusicGen")

    with gr.Row():
        with gr.Column():
            prompt_duration = gr.Dropdown(
                label="Prompt Duration (seconds)",
                choices=list(range(1, 11)),
                value=5
            )
            output_duration = gr.Slider(
                label="Output Duration (seconds)",
                minimum=10,
                maximum=30,
                step=1,
                value=20
            )
        with gr.Column():
            musicgen_model = gr.Dropdown(
                label="MusicGen Model",
                choices=[
                    "thepatch/vanya_ai_dnb_0.1 (small)",
                    "thepatch/budots_remix (small)",
                    "thepatch/PhonkV2 (small)",
                    "thepatch/bleeps-medium (medium)",
                    "thepatch/hoenn_lofi (large)",
                    "foureyednymph/musicgen-sza-sos-small (small)"
                ],
                value="thepatch/vanya_ai_dnb_0.1 (small)"
            )

    # Two different continuation options with clear explanations
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔄 Continue from Beginning")
            gr.Markdown("*Uses the **first** X seconds as prompt. Good for reimagining/reworking from a starting point.*")
            generate_music_button = gr.Button("🔄 Continue from Beginning", variant="primary", size="lg")
        with gr.Column():
            gr.Markdown("### ➡️ Extend from End")
            gr.Markdown("*Uses the **last** X seconds as prompt. Extends your audio by adding new content to the end.*")
            continue_music_button = gr.Button("➡️ Extend from End", variant="secondary", size="lg")

    # ========== EVENT HANDLERS ==========

    # Step 1: Generate
    generate_button.click(generate_drum_sample, outputs=[main_audio])
    continue_drum_button.click(continue_drum_sample, inputs=[main_audio], outputs=[main_audio])

    # Step 2: Transform (using the Facebook MelodyFlow API)
    transform_button.click(
        transform_with_melodyflow_api,
        inputs=[main_audio, transform_variation, transform_prompt, transform_solver, transform_flowstep],
        outputs=[main_audio, transform_status]
    )

    # Step 3: Continue (two different approaches)
    generate_music_button.click(
        generate_music,
        inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
        outputs=[main_audio]
    )
    continue_music_button.click(
        continue_music,
        inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
        outputs=[main_audio]
    )

if __name__ == "__main__":
    iface.launch()