import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio
import sys
import subprocess
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler
import pickle
from train.train_params import params_chord_lsh_cond
from generation.gen_utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'

chord_list = list(CHORD_DICTIONARY.keys())


def get_shape(file_path):
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape  # (height, width, channels)
    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']  # (width, height)
    else:
        raise ValueError("Unsupported file type")


# Function to convert MIDI to WAV
def midi_to_wav(midi, output_file):
    # Synthesize the waveform from the MIDI using pretty_midi
    audio_data = midi.fluidsynth()
    # Write the waveform to a WAV file
    sf.write(output_file, audio_data, samplerate=44100)


def update_musescore_image(selected_prompt):
    # Logic to return the correct image file based on the selected prompt
    if selected_prompt == "example 1":
        return "samples/diy_examples/example1/example1.jpg"
    elif selected_prompt == "example 2":
        return "samples/diy_examples/example2/example2.jpg"
    elif selected_prompt == "example 3":
        return "samples/diy_examples/example3/example3.jpg"
    elif selected_prompt == "example 4":
        return "samples/diy_examples/example4/example4.jpg"
    elif selected_prompt == "example 5":
        return "samples/diy_examples/example5/example5.jpg"
    elif selected_prompt == "example 6":
        return "samples/diy_examples/example6/example6.jpg"


# Model for generating music
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode == "example":
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo = 70
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")
        background_condition = np.tile(background_condition, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        background_condition = np.tile(prompt, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control != "Yes":
        background_condition[:, 0:2] = background_condition[:, 2:4]

    # generate samples
    output_x = sampler.generate(background_cond=background_condition,
                                batch_size=num_samples,
                                same_noise_all_measure=False,
                                X0EditFunc=X0EditFunc,
                                use_classifier_free_guidance=True,
                                use_lsh=True,
                                reduce_extra_notes=False,
                                rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    # save samples
    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i])  # accompaniment roll
        full_chd_roll = extend_piano_roll(-background_condition[i, 2:4, :, :].cpu().numpy() - 1)  # chord roll
        full_lsh_roll = None
        if background_condition.shape[1] >= 6:
            if background_condition[:, 4:6, :, :].min() >= 0:
                full_lsh_roll = extend_piano_roll(background_condition[i, 4:6, :, :].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        subprocess.Popen(['timidity', f'output_{i}.mid', '-Ow', '-o', f'output_{i}.wav']).communicate()

    return 'output_0.mid', 'output_0.wav', midi_file


# Function to visualize MIDI notes
def visualize_midi(midi):
    # Get piano roll from MIDI
    roll = midi.get_piano_roll(fs=100)

    # Plot the piano roll
    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time")
    plt.ylabel("Pitch")
    plt.colorbar()

    # Save the plot as an image
    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    return output_image_path


# Gradio main function
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control="No")
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image


# Prompt list
prompt_list = ["example 1", "example 2", "example 3", "example 4"]

custom_css = """
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
.audio_waveform-container {
    display: none !important;
}
"""
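
# --- Optional usage sketch (not called by the app) ---
# A minimal illustration of how generate_music() could be invoked directly,
# outside the Gradio UI, using one of the bundled example condition arrays.
# The assumed channel layout (chord channels at 2:4, lead-sheet channels at
# 4:6, on a 64x64 grid) is inferred from how generate_music() indexes
# background_condition above; the function and its default path are
# illustrative, not part of the original app.
def _sketch_direct_generation(example_npy="samples/diy_examples/example2/example2.npy"):
    # Load a pre-built background condition and pass the raw array;
    # any mode other than "example" takes the array branch in generate_music().
    background_condition = np.load(example_npy)
    midi_path, wav_path, midi_obj = generate_music(
        background_condition,  # raw condition array instead of an "example N" string
        tempo=80,
        num_samples=1,
        mode="custom",
        rhythm_control="No",
    )
    return midi_path, wav_path
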
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Efficient Fine-Grained Guidance for Diffusion Model Based Symbolic Music Generation")
    gr.Markdown("Tingyu Zhu*, Haoyu Liu*, Ziyu Wang, Zhimin Jiang, Zeyu Zheng")
    gr.Markdown("[Paper] [Code Repo]")
    gr.Markdown(
        "For detailed information and demonstrations of our method, please visit our "
        "[GitHub Pages site](https://huajianduzhuo-code.github.io/FGG-diffusion-music/) to explore:"
        "\n   1. Accompaniment Generation given Melody and Chord"
        "\n   2. Style-Controlled Music Generation"
        "\n   3. Demonstrating the Effectiveness of Sampling Control by Comparison"
    )
    gr.HTML("")
    gr.Markdown("\n\n\n")
    gr.Markdown("# Interactive Demo")
    gr.Markdown(
        "🎵 Try out our interactive tool to generate music with our model!<br>"
        "You can create new accompaniments conditioned on a given melody and chord progression."
    )
    gr.Markdown(
        "⚠️ This Space currently runs on a Hugging Face-provided CPU. "
        "On average, it takes ~15 seconds to generate a 4-measure music segment.<br>"
        "If multiple users are generating at the same time, you may enter a queue, which can cause delays.<br><br>"
        "🚀 On our local server (NVIDIA RTX 6000 Ada GPU), the same generation takes only 0.4 seconds.<br><br>"
        "To speed things up, you can:<br>"
        "• 🔁 Fork this Space and select a different hardware configuration<br>"
        "• 🧑‍💻 Clone our [Code Repo] and run the generation notebooks locally "
        "after installing dependencies and downloading the model weights."
    )

    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg",
                                               label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)
            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")
                generate_button.click(
                    fn=generate_from_example,
                    inputs=[prompt_selector],
                    outputs=[audio_output, piano_roll_output]
                )

# Launch Gradio interface
if __name__ == "__main__":
    demo.launch()