import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt
from pydub import AudioSegment
import pretty_midi as pm
from midi2audio import FluidSynth

from VAE import VAE

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load the VAE model once and cache it across Streamlit reruns
@st.cache_resource
def load_model():
    vae = VAE(input_dim=76, hidden_dim=512, latent_dim=256)
    vae.load_state_dict(torch.load("vae_model_all.pth", map_location=device))
    vae = vae.to(device)
    vae.eval()
    return vae


# Process the uploaded MIDI file into right-hand and left-hand piano rolls
def process_midi(file):
    try:
        mid = pm.PrettyMIDI(file)
        fs = 10  # piano-roll sampling rate (frames per second)
        hand_dict = {"right": None, "left": None}
        times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        # Keep only piano-family instruments (General MIDI programs 0-7) that
        # actually contain notes, and record each one's average pitch. Tracking
        # the instruments that pass the filter in a parallel list keeps the
        # argmax/argmin indices aligned with pitch_list even when tracks are skipped.
        valid_instruments = []
        pitch_list = []
        for inst in mid.instruments:
            if inst.program // 8 > 0:
                continue
            piano_roll = inst.get_piano_roll(times=times)
            if np.sum(piano_roll) == 0:
                continue
            hand_pitch = np.where(piano_roll)
            valid_instruments.append(inst)
            pitch_list.append(np.average(hand_pitch[0]))

        if len(pitch_list) == 0:
            st.error("No valid piano data found.")
            return None, None
        elif len(pitch_list) == 1:
            # Single track: treat it as the right hand and leave the left hand empty.
            hand_dict["right"] = valid_instruments[0].get_piano_roll(times=times)
            hand_dict["left"] = np.zeros_like(hand_dict["right"])
        else:
            # The highest-pitched track becomes the right hand, the lowest the left.
            hand_dict["right"] = valid_instruments[int(np.argmax(pitch_list))].get_piano_roll(times=times)
            hand_dict["left"] = valid_instruments[int(np.argmin(pitch_list))].get_piano_roll(times=times)

        # Crop to 76 pitches (MIDI 24-99) and a 150-frame window starting at frame 26.
        pitch_start, pitch_stop, duration = 24, 100, 150
        right_hand = hand_dict["right"][pitch_start:pitch_stop, 26:26 + duration]
        left_hand = hand_dict["left"][pitch_start:pitch_stop, 26:26 + duration]

        # The model expects exactly `duration` frames per hand.
        if right_hand.shape[1] < duration:
            st.error("MIDI file is too short for the model's input window.")
            return None, None
        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None
        return right_hand, left_hand
    except Exception as e:
        st.error(f"Error processing MIDI: {e}")
        return None, None


# Run the VAE model for reconstruction
def reconstruct(right, left, model):
    # Inputs are (time, pitch); the two hands are stacked along the time axis
    # to form a single (1, 300, 76) batch for the model.
    right_tensor = torch.tensor(right, dtype=torch.float32).to(device)
    left_tensor = torch.tensor(left, dtype=torch.float32).to(device)
    input_tensor = torch.cat([right_tensor, left_tensor], dim=0).unsqueeze(0)
    with torch.no_grad():
        recon_data, _, _, _ = model(input_tensor)
    return recon_data.squeeze(0).cpu().numpy()


# Render a MIDI file to WAV with FluidSynth, then apply a volume boost
def midi_to_wav(midi_file, wav_file="output.wav", sound_font_path="soundfont.sf2", volume_increase_db=17):
    synth = FluidSynth(sound_font_path)
    synth.midi_to_audio(midi_file, wav_file)
    audio = AudioSegment.from_wav(wav_file)
    louder_audio = audio + volume_increase_db  # pydub gain is in dB
    louder_audio.export(wav_file, format="wav")
    return wav_file


# Create a PrettyMIDI object from right- and left-hand piano roll data
def create_midi_from_piano_roll(right_hand, left_hand, fs=8):
    pm_obj = pm.PrettyMIDI()
    threshold = 0.92  # activation level treated as a sounding note
    for hand_data in [right_hand, left_hand]:
        instrument = pm.Instrument(program=0)  # Acoustic Grand Piano
        pm_obj.instruments.append(instrument)
        for pitch, series in enumerate(hand_data):
            # A note starts on a rising threshold crossing and ends on a
            # falling one. If the pitch is already active in the first frame,
            # open the note at time zero so it is not dropped.
            start_time = 0.0 if series[0] >= threshold else None
            for j in range(len(series) - 1):
                if series[j] < threshold and series[j + 1] >= threshold:
                    start_time = j / fs
                elif series[j] >= threshold and series[j + 1] < threshold and start_time is not None:
                    end_time = (j + 1) / fs
                    # pitch + 24 undoes the pitch crop applied in process_midi
                    note = pm.Note(velocity=100, pitch=pitch + 24, start=start_time, end=end_time)
                    instrument.notes.append(note)
                    start_time = None
            # Close any note still sounding at the end of the roll.
            if start_time is not None:
                end_time = len(series) / fs
                note = pm.Note(velocity=100, pitch=pitch + 24, start=start_time, end=end_time)
                instrument.notes.append(note)
    return pm_obj


# Convert piano-roll data to a MIDI file on disk
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    midi_data = create_midi_from_piano_roll(right_hand, left_hand, fs=fs)
    midi_data.write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name


# Streamlit interface
st.title("GRU-VAE Reconstruction Demo")
model = load_model()

# File upload
uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])

if uploaded_file is not None:
    st.write("Processing MIDI file...")
    right_hand, left_hand = process_midi(uploaded_file)

    if right_hand is not None and left_hand is not None:
        # Display the original piano rolls
        st.write("Original MIDI Data:")
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        axs[0].imshow(right_hand, aspect="auto", cmap="gray")
        axs[0].set_title("Right Hand")
        axs[1].imshow(left_hand, aspect="auto", cmap="gray")
        axs[1].set_title("Left Hand")
        st.pyplot(fig)

        # Reconstruction: the model takes (time, pitch) input, so transpose in;
        # the first 150 output frames are the right hand, the rest the left.
        recon_data = reconstruct(right_hand.T, left_hand.T, model)
        recon_right = recon_data[:150].T
        recon_left = recon_data[150:].T

        # Display the reconstructed piano rolls
        st.write("Reconstructed MIDI Data:")
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        axs[0].imshow(recon_right, aspect="auto", cmap="gray")
        axs[0].set_title("Right Hand (Reconstructed)")
        axs[1].imshow(recon_left, aspect="auto", cmap="gray")
        axs[1].set_title("Left Hand (Reconstructed)")
        st.pyplot(fig)

        # Convert both versions to MIDI, then to WAV for playback
        original_midi = convert_to_midi(right_hand, left_hand, "original_output.mid", fs=8)
        recon_midi = convert_to_midi(recon_right, recon_left, "reconstructed_output.mid", fs=8)

        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format="audio/wav")
        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format="audio/wav")
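

# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): the `VAE` class is imported from a local
# module that is not shown in this file. The code above only assumes a
# constructor VAE(input_dim=76, hidden_dim=512, latent_dim=256) and a forward
# pass that maps a (batch, time, pitch) tensor to four outputs --
# reconstruction, mu, logvar, and a latent sample. A minimal GRU-based VAE
# with that interface might look like `SketchGRUVAE` below; the actual
# architecture behind "vae_model_all.pth" may differ, so treat this as a
# reference sketch only, not the shipped model definition.
import torch.nn as nn


class SketchGRUVAE(nn.Module):
    def __init__(self, input_dim=76, hidden_dim=512, latent_dim=256):
        super().__init__()
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, hidden_dim)
        self.decoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # Encode the whole sequence; h is (num_layers, batch, hidden_dim).
        _, h = self.encoder(x)
        mu = self.fc_mu(h[-1])
        logvar = self.fc_logvar(h[-1])
        # Reparameterization trick: sample z = mu + sigma * eps.
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        # Repeat the latent state across time so the decoder GRU can unroll it.
        dec_in = self.fc_dec(z).unsqueeze(1).repeat(1, x.size(1), 1)
        dec_out, _ = self.decoder(dec_in)
        recon = torch.sigmoid(self.out(dec_out))  # activations in [0, 1]
        return recon, mu, logvar, z

# Shape check for the sketch (matches how `reconstruct` calls the model):
#   model = SketchGRUVAE()
#   recon, mu, logvar, z = model(torch.zeros(1, 300, 76))
#   recon.shape  # -> torch.Size([1, 300, 76])
# ---------------------------------------------------------------------------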