Spaces: Running on Zero
import gradio as gr
import spaces
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import tempfile
import os
import torch
from gradio_client import Client, handle_file
import random
import time
import io
from pydub import AudioSegment

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MelodyFlow variation mapping - map your semantic variations to text prompts
VARIATION_PROMPTS = {
    'accordion_folk': 'folk accordion melody with traditional folk instruments',
    'banjo_bluegrass': 'bluegrass banjo with country folk instruments',
    'piano_classical': 'classical piano with orchestral arrangement',
    'celtic': 'celtic harp and flute with traditional irish instruments',
    'strings_quartet': 'string quartet with violin, viola, cello arrangement',
    'synth_retro': 'retro 80s synthesizer with vintage electronic sounds',
    'synth_modern': 'modern synthesizer with contemporary electronic production',
    'synth_edm': 'edm synthesizer with dance electronic beats',
    'lofi_chill': 'lo-fi chill with relaxed jazz hip-hop elements',
    'synth_bass': 'heavy bass synthesizer with sub-bass frequencies',
    'rock_band': 'rock band with electric guitar, bass, and drums',
    'cinematic_epic': 'cinematic epic orchestral with dramatic strings and brass',
    'retro_rpg': 'retro rpg chiptune with 8-bit game music elements',
    'chiptune': '8-bit chiptune with retro video game sounds',
    'steel_drums': 'steel drums with caribbean tropical percussion',
    'gamelan_fusion': 'gamelan fusion with indonesian percussion instruments',
    'music_box': 'music box with delicate mechanical melody',
    'trap_808': 'trap beats with heavy 808 drums and hi-hats',
    'lo_fi_drums': 'lo-fi drums with vinyl crackle and jazz samples',
    'boom_bap': 'boom bap hip-hop with classic drum breaks',
    'percussion_ensemble': 'percussion ensemble with varied drum instruments',
    'future_bass': 'future bass with melodic drops and vocal chops',
    'synthwave_retro': 'synthwave retro with neon 80s aesthetic',
    'melodic_techno': 'melodic techno with driving beats and emotional melodies',
    'dubstep_wobble': 'dubstep with heavy wobble bass and electronic drops',
    'glitch_hop': 'glitch hop with broken beats and digital artifacts',
    'digital_disruption': 'digital disruption with glitchy electronic effects',
    'circuit_bent': 'circuit bent with broken electronic hardware sounds',
    'orchestral_glitch': 'orchestral glitch with classical instruments and digital errors',
    'vapor_drums': 'vaporwave drums with slowed down nostalgic beats',
    'industrial_textures': 'industrial textures with harsh mechanical sounds',
    'jungle_breaks': 'jungle breaks with fast drum and bass rhythms'
}
def preprocess_audio(waveform):
    waveform_np = waveform.cpu().squeeze().numpy()
    return torch.from_numpy(waveform_np).unsqueeze(0).to(device)
# ========== MUSICGEN FUNCTIONS (Local ZeroGPU) ==========

@spaces.GPU  # request a ZeroGPU slice for the duration of this call
def generate_drum_sample():
    model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
    model.set_generation_params(duration=10)
    wav = model.generate_unconditional(1).squeeze(0)
    filename_without_extension = 'jungle'
    filename_with_extension = f'{filename_without_extension}.wav'
    audio_write(filename_without_extension, wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
    return filename_with_extension
@spaces.GPU
def continue_drum_sample(existing_audio_path):
    if existing_audio_path is None:
        return None
    existing_audio, sr = torchaudio.load(existing_audio_path)
    existing_audio = existing_audio.to(device)
    prompt_duration = 2
    output_duration = 10
    num_samples = int(prompt_duration * sr)
    if existing_audio.shape[1] < num_samples:
        raise ValueError("The existing audio is too short for the specified prompt duration.")
    start_sample = existing_audio.shape[1] - num_samples
    prompt_waveform = existing_audio[..., start_sample:]
    model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
    model.set_generation_params(duration=output_duration)
    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.to(device)
    if output.dim() == 3:
        output = output.squeeze(0)
    if output.dim() == 1:
        output = output.unsqueeze(0)
    combined_audio = torch.cat((existing_audio, output), dim=1)
    combined_audio = combined_audio.cpu()
    combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
    torchaudio.save(combined_file_path, combined_audio, sr)
    return combined_file_path
@spaces.GPU
def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
    """Generate music using the BEGINNING of the audio as the prompt."""
    if wav_filename is None:
        return None
    song, sr = torchaudio.load(wav_filename)
    song = song.to(device)
    model_name = musicgen_model.split(" ")[0]
    model_continue = MusicGen.get_pretrained(model_name)
    model_continue.set_generation_params(
        use_sampling=True,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        duration=output_duration,
        cfg_coef=3
    )
    prompt_waveform = song[..., :int(prompt_duration * sr)]
    prompt_waveform = preprocess_audio(prompt_waveform)
    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.cpu()
    if len(output.size()) > 2:
        output = output.squeeze()
    filename_without_extension = 'continued_music'
    filename_with_extension = f'{filename_without_extension}.wav'
    audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
    return filename_with_extension
@spaces.GPU
def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
    """Continue music using the END of the audio as the prompt - extends the audio."""
    if input_audio_path is None:
        return None
    song, sr = torchaudio.load(input_audio_path)
    song = song.to(device)
    model_name = musicgen_model.split(" ")[0]
    model_continue = MusicGen.get_pretrained(model_name)
    model_continue.set_generation_params(
        use_sampling=True,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        duration=output_duration,
        cfg_coef=3
    )

    # Load the original audio as an AudioSegment for easier manipulation
    original_audio = AudioSegment.from_wav(input_audio_path)
    file_paths_for_cleanup = []

    # Get the last `prompt_duration` seconds as the prompt
    num_samples = int(prompt_duration * sr)
    if song.shape[1] < num_samples:
        raise ValueError("The prompt_duration is longer than the current audio length.")

    # Extract the end portion for prompting
    start_sample = song.shape[1] - num_samples
    prompt_waveform = song[..., start_sample:]
    prompt_waveform = preprocess_audio(prompt_waveform)

    # Generate the continuation
    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
    output = output.cpu()
    if len(output.size()) > 2:
        output = output.squeeze()

    # Save the generated audio WITHOUT aggressive loudness processing
    filename_without_extension = f'continue_extension_{random.randint(1000, 9999)}'
    filename_with_extension = f'{filename_without_extension}.wav'
    audio_write(filename_without_extension, output, model_continue.sample_rate,
                strategy="clip")  # just prevent clipping, no loudness changes

    # Handle the double .wav extension issue
    correct_filename = f'{filename_without_extension}.wav.wav'
    if os.path.exists(correct_filename):
        generated_audio_segment = AudioSegment.from_wav(correct_filename)
        file_paths_for_cleanup.append(correct_filename)
    else:
        generated_audio_segment = AudioSegment.from_wav(filename_with_extension)
        file_paths_for_cleanup.append(filename_with_extension)

    # SMART VOLUME MATCHING: only match the prompt portion
    # 1. Remove the prompt duration from the original (no overlap)
    prompt_duration_ms = int(prompt_duration * 1000)
    original_minus_prompt = original_audio[:-prompt_duration_ms]

    # 2. Extract JUST the prompt portion from the generated audio for RMS analysis
    generated_prompt_portion = generated_audio_segment[:prompt_duration_ms]

    # 3. Calculate RMS at the transition point
    original_rms = original_minus_prompt.rms
    prompt_portion_rms = generated_prompt_portion.rms

    print("Smart volume analysis:")
    print(f"   Original (minus prompt) RMS: {original_rms}")
    print(f"   Generated prompt RMS: {prompt_portion_rms}")
    print(f"   Generated full RMS: {generated_audio_segment.rms}")

    # 4. Match the prompt portion to the original level
    if prompt_portion_rms > 0:
        from pydub.utils import ratio_to_db
        volume_adjustment = ratio_to_db(original_rms / prompt_portion_rms)
        print(f"   Applying {volume_adjustment:.1f} dB to the entire generated segment")
        # Apply to the entire segment (preserves the buildup)
        generated_matched = generated_audio_segment + volume_adjustment
    else:
        generated_matched = generated_audio_segment

    # 5. Combine seamlessly
    combined_audio = original_minus_prompt + generated_matched

    # Save the final result
    combined_audio_filename = f"extended_audio_{random.randint(1000, 9999)}.wav"
    combined_audio.export(combined_audio_filename, format="wav")

    # Clean up temporary files
    for file_path in file_paths_for_cleanup:
        if os.path.exists(file_path):
            os.remove(file_path)

    return combined_audio_filename
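# --- Illustrative sketch (not part of the original app) ---------------------
# The volume matching above leans on pydub's ratio_to_db(), which for an
# amplitude ratio is 20 * log10(ratio). This hypothetical helper shows the same
# arithmetic in isolation, assuming RMS values taken from pydub AudioSegments.
def _rms_match_gain_db(target_rms: float, current_rms: float) -> float:
    """Return the dB gain that scales audio at current_rms to roughly target_rms."""
    import math
    if current_rms <= 0:
        return 0.0  # silent segment: apply no gain rather than dividing by zero
    return 20 * math.log10(target_rms / current_rms)
# Example: _rms_match_gain_db(1000, 2000) is about -6.0 dB (halve the amplitude).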
# ========== MELODYFLOW FUNCTIONS (Via Facebook Space) ==========

def transform_with_melodyflow_api(audio_path, variation, custom_prompt="", solver="euler", flowstep=0.12):
    """Transform audio using the facebook/MelodyFlow space API."""
    if audio_path is None:
        return None, "❌ No audio file provided"

    # Set steps based on the solver. Facebook's space automatically reduces the
    # step count for editing: EULER divides by 5, MIDPOINT divides by 2.
    if solver == "midpoint":
        base_steps = 128
        effective_steps = base_steps // 2  # 64 effective steps
    else:  # euler
        base_steps = 125
        effective_steps = base_steps // 5  # 25 effective steps

    try:
        # Initialize the client for the Facebook MelodyFlow space
        client = Client("facebook/MelodyFlow")

        # Determine the prompt to use
        if custom_prompt.strip():
            prompt_text = custom_prompt.strip()
            status_msg = f"✅ Transformed with custom prompt: '{prompt_text}' (flowstep: {flowstep}, {effective_steps} steps)"
        else:
            prompt_text = VARIATION_PROMPTS.get(variation, f"transform this audio to {variation} style")
            status_msg = f"✅ Transformed with {variation} style (flowstep: {flowstep}, {effective_steps} steps)"

        # Call the MelodyFlow API with the base steps (the space will auto-reduce them)
        result = client.predict(
            model="facebook/melodyflow-t24-30secs",
            text=prompt_text,
            solver=solver,
            steps=base_steps,  # auto-reduced to effective_steps by the space
            target_flowstep=flowstep,  # this is the key parameter!
            regularize=solver == "euler",  # regularize for euler, not for midpoint
            regularization_strength=0.2,
            duration=30,  # max duration
            melody=handle_file(audio_path),
            api_name="/predict"
        )

        # Result is a tuple of audio files (variations); use the first one
        if result and len(result) > 0 and result[0]:
            # Save the result locally
            output_filename = f"melodyflow_{variation}_{random.randint(1000, 9999)}.wav"
            import shutil
            shutil.copy2(result[0], output_filename)
            return output_filename, status_msg
        else:
            return None, "❌ MelodyFlow API returned no results"

    except Exception as e:
        return None, f"❌ MelodyFlow API error: {str(e)}"
# ========== GRADIO INTERFACE ==========

# Create the interface
with gr.Blocks() as iface:
    gr.Markdown("# 🎰 The Mega Slot Machine")
    gr.Markdown("**Hybrid Multi-Model Pipeline**: MicroMusicGen → MelodyFlow (via API) → MusicGen Fine-tunes")
    gr.Markdown("*Demonstrating the workflow from our Ableton device in a web interface!*")

    with gr.Accordion("How This Works", open=False):
        gr.Markdown("""
        This demo shows how multiple AI models can work together:

        1. **Generate** initial audio with MicroMusicGen (super fast jungle drums)
        2. **Transform** it using MelodyFlow (via Facebook's space API)
        3. **Continue** with MusicGen fine-tunes (trained on specific genres)
        4. **Repeat** the cycle to create infinite musical journeys!

        MelodyFlow requires a different PyTorch version than the local models, so we call Facebook's MelodyFlow space via its API.

        **Performance Note**: For audio transformation, MelodyFlow automatically uses fewer steps than generation:
        - EULER solver: 25 effective steps (fast, good quality)
        - MIDPOINT solver: 64 effective steps (slower, potentially higher quality)
        """)
    # ========== STEP 1: GENERATE ==========
    gr.Markdown("## 🎵 Step 1: Generate Initial Audio")

    with gr.Row():
        with gr.Column():
            generate_button = gr.Button("Generate Jungle Drums", variant="primary", size="lg")
            continue_drum_button = gr.Button("Continue Drums", size="sm")

    main_audio = gr.Audio(
        label="🎵 Current Audio (flows through pipeline)",
        type="filepath",
        interactive=True,
        show_download_button=True
    )

    # ========== STEP 2: TRANSFORM ==========
    gr.Markdown("## 🎛️ Step 2: Transform with MelodyFlow")

    with gr.Row():
        with gr.Column(scale=2):
            transform_variation = gr.Dropdown(
                label="Transform Style",
                choices=list(VARIATION_PROMPTS.keys()),
                value="synth_modern",
                interactive=True
            )
        with gr.Column(scale=3):
            transform_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Leave empty to use style above, or enter custom transformation prompt",
                lines=2
            )

    with gr.Row():
        transform_solver = gr.Dropdown(
            label="Solver",
            choices=["euler", "midpoint"],
            value="euler",
            info="EULER: faster (25 steps), MIDPOINT: slower but potentially higher quality (64 steps)"
        )
        transform_flowstep = gr.Slider(
            label="Transform Intensity (Flowstep)",
            minimum=0.0,
            maximum=0.15,
            step=0.01,
            value=0.12,
            info="Lower values = more dramatic transformation"
        )

    transform_button = gr.Button("🎛️ Transform Audio", variant="secondary", size="lg")
    transform_status = gr.Textbox(label="Transform Status", value="Ready to transform", interactive=False)
    # ========== STEP 3: CONTINUE ==========
    gr.Markdown("## 🎼 Step 3: Continue with MusicGen")

    with gr.Row():
        with gr.Column():
            prompt_duration = gr.Dropdown(
                label="Prompt Duration (seconds)",
                choices=list(range(1, 11)),
                value=5
            )
            output_duration = gr.Slider(
                label="Output Duration (seconds)",
                minimum=10,
                maximum=30,
                step=1,
                value=20
            )
        with gr.Column():
            musicgen_model = gr.Dropdown(
                label="MusicGen Model",
                choices=[
                    "thepatch/vanya_ai_dnb_0.1 (small)",
                    "thepatch/budots_remix (small)",
                    "thepatch/PhonkV2 (small)",
                    "thepatch/bleeps-medium (medium)",
                    "thepatch/hoenn_lofi (large)",
                    "foureyednymph/musicgen-sza-sos-small (small)"
                ],
                value="thepatch/vanya_ai_dnb_0.1 (small)"
            )

    # Two different continuation options with clear explanations
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔄 Continue from Beginning")
            gr.Markdown("*Uses the **first** X seconds as prompt. Good for reimagining/reworking from a starting point.*")
            generate_music_button = gr.Button("🔄 Continue from Beginning", variant="primary", size="lg")
        with gr.Column():
            gr.Markdown("### ➡️ Extend from End")
            gr.Markdown("*Uses the **last** X seconds as prompt. Extends your audio by adding new content to the end.*")
            continue_music_button = gr.Button("➡️ Extend from End", variant="secondary", size="lg")
    # ========== EVENT HANDLERS ==========

    # Step 1: Generate
    generate_button.click(generate_drum_sample, outputs=[main_audio])
    continue_drum_button.click(continue_drum_sample, inputs=[main_audio], outputs=[main_audio])

    # Step 2: Transform (using the Facebook MelodyFlow API)
    transform_button.click(
        transform_with_melodyflow_api,
        inputs=[main_audio, transform_variation, transform_prompt, transform_solver, transform_flowstep],
        outputs=[main_audio, transform_status]
    )

    # Step 3: Continue (two different approaches)
    generate_music_button.click(
        generate_music,
        inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
        outputs=[main_audio]
    )
    continue_music_button.click(
        continue_music,
        inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
        outputs=[main_audio]
    )

if __name__ == "__main__":
    iface.launch()