ec98's picture
Update app.py
1b40d3e verified
raw
history blame
No virus
11.3 kB
import collections
#import datetime
#import glob
import numpy as np
#import pathlib
import pandas as pd
import pretty_midi
import seaborn as sns
from matplotlib import pyplot as plt
from typing import Optional
import tensorflow as tf
import keras
from tensorflow.keras.utils import custom_object_scope
import streamlit as st
from midi2audio import FluidSynth
import tempfile
import os
import base64
def midi_to_notes(midi_file: str) -> pd.DataFrame:
    """Extract the notes of the first instrument of a MIDI file as a table.

    Args:
        midi_file: Path (or any source accepted by ``pretty_midi.PrettyMIDI``)
            of the MIDI file to read.

    Returns:
        A DataFrame with one row per note, ordered by start time, and the
        columns ``pitch``, ``start``, ``end``, ``step`` (start minus the
        previous note's start; 0 for the first note) and ``duration``
        (end minus start). If the file has no instruments, or its first
        instrument has no notes, an empty DataFrame with those same columns
        is returned instead of raising ``IndexError``.
    """
    columns = ('pitch', 'start', 'end', 'step', 'duration')
    pm = pretty_midi.PrettyMIDI(midi_file)

    # Guard the two places the original indexed blindly with [0].
    if not pm.instruments or not pm.instruments[0].notes:
        return pd.DataFrame({name: np.array([]) for name in columns})

    sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: note.start)
    notes = collections.defaultdict(list)
    prev_start = sorted_notes[0].start
    for note in sorted_notes:
        start = note.start
        end = note.end
        notes['pitch'].append(note.pitch)
        notes['start'].append(start)
        notes['end'].append(end)
        notes['step'].append(start - prev_start)
        notes['duration'].append(end - start)
        prev_start = start

    return pd.DataFrame({name: np.array(notes[name]) for name in columns})
def notes_to_midi(
    notes: pd.DataFrame,
    out_file: str,
    instrument_name: str,
    velocity: int = 100,
) -> pretty_midi.PrettyMIDI:
    """Write a table of notes to a single-instrument MIDI file.

    Args:
        notes: Table read row by row through its ``pitch``, ``step`` and
            ``duration`` columns. ``step`` is the offset from the previous
            note's start (the first note is offset from time 0).
        out_file: Destination passed to ``PrettyMIDI.write``.
        instrument_name: Instrument name resolved to a program number with
            ``pretty_midi.instrument_name_to_program``.
        velocity: Velocity given to every note.

    Returns:
        The ``PrettyMIDI`` object that was written.
    """
    program = pretty_midi.instrument_name_to_program(instrument_name)
    track = pretty_midi.Instrument(program=program)

    # `onset` carries the absolute start time forward from one row to the next.
    onset = 0
    for _, row in notes.iterrows():
        onset = float(onset + row['step'])
        track.notes.append(
            pretty_midi.Note(
                velocity=velocity,
                pitch=int(row['pitch']),
                start=onset,
                end=float(onset + row['duration']),
            )
        )

    midi = pretty_midi.PrettyMIDI()
    midi.instruments.append(track)
    midi.write(out_file)
    return midi
def plot_roll(notes: pd.DataFrame, count: Optional[int] = None):
    """Draw a piano-roll style plot (pitch against time) of the notes.

    Args:
        notes: Table with ``pitch``, ``start`` and ``end`` columns.
        count: Number of leading notes to draw. A falsy value (``None`` or 0)
            draws the whole track.
    """
    if not count:
        title = 'Whole track'
        count = len(notes['pitch'])
    else:
        title = f'First {count} notes'

    plt.figure(figsize=(20, 4))

    # Each note becomes a two-point segment: (start, pitch) -> (end, pitch).
    pitches = np.stack([notes['pitch']] * 2, axis=0)
    spans = np.stack([notes['start'], notes['end']], axis=0)
    plt.plot(spans[:, :count], pitches[:, :count], color="b", marker=".")

    plt.xlabel('Time [s]')
    plt.ylabel('Pitch')
    plt.title(title)
def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
    """Plot histograms of pitch, step and duration side by side.

    The figure is laid out as a 1x3 grid. The original drew only the first
    two panels and left the third empty; the duration histogram is now drawn
    there, built the same way as the step histogram.

    Args:
        notes: Table with ``pitch``, ``step`` and ``duration`` columns.
        drop_percentile: Upper percentage of the ``step`` / ``duration``
            values left out of the histogram range, so a few extreme values
            do not flatten the plot.
    """
    plt.figure(figsize=[15, 5])

    plt.subplot(1, 3, 1)
    sns.histplot(notes, x="pitch", bins=20)

    plt.subplot(1, 3, 2)
    max_step = np.percentile(notes['step'], 100 - drop_percentile)
    sns.histplot(notes, x="step", bins=np.linspace(0, max_step, 21))

    plt.subplot(1, 3, 3)
    max_duration = np.percentile(notes['duration'], 100 - drop_percentile)
    sns.histplot(notes, x="duration", bins=np.linspace(0, max_duration, 21))
def predict_next_note(
    notes: np.ndarray,
    model: tf.keras.Model,
    temperature: float = 1.0) -> tuple[int, float, float]:
    """Sample the next note from the model given a window of previous notes.

    Args:
        notes: Window of previous notes; a leading batch axis is added before
            it is passed to ``model.predict``.
        model: Model whose ``predict`` output is indexed by the keys
            ``'pitch'``, ``'step'`` and ``'duration'``.
        temperature: Positive divisor applied to the pitch logits before
            sampling.

    Returns:
        ``(pitch, step, duration)``, where step and duration are clamped to
        be non-negative.
    """
    assert temperature > 0

    batch = tf.expand_dims(notes, 0)
    outputs = model.predict(batch)

    # Pitch is sampled from the temperature-scaled logits.
    logits = outputs['pitch']
    logits /= temperature
    sampled = tf.random.categorical(logits, num_samples=1)
    pitch = tf.squeeze(sampled, axis=-1)

    # Step and duration are taken directly from the model output and floored at zero.
    step = tf.maximum(0, tf.squeeze(outputs['step'], axis=-1))
    duration = tf.maximum(0, tf.squeeze(outputs['duration'], axis=-1))

    return int(pitch), float(step), float(duration)
def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
    """Mean squared error plus a penalty on negative predictions.

    Each element contributes its squared error plus ``10 * max(-y_pred, 0)``,
    so only predictions below zero are penalised; the mean over all elements
    is returned.
    """
    error = y_true - y_pred
    squared_error = error ** 2
    negative_penalty = 10 * tf.maximum(-y_pred, 0.0)
    return tf.reduce_mean(squared_error + negative_penalty)
def calcular_duracion_midi(archivo_midi):
    """Return the end time of the MIDI file, as given by ``get_end_time()``."""
    return pretty_midi.PrettyMIDI(archivo_midi).get_end_time()
def obtain_statistics(midi_file: str) -> pd.DataFrame:
    """Summarise pitch, step and duration of a MIDI file.

    Args:
        midi_file: MIDI file handed to ``midi_to_notes``.

    Returns:
        A three-row DataFrame (mean, standard deviation, median) with a label
        column ``'Estadísticas'`` and one column per measured quantity:
        ``'Tono'`` (pitch), ``'Paso'`` (step) and ``'Duración'`` (duration).
    """
    notes_df = midi_to_notes(midi_file)

    # Row order of the output: mean, standard deviation, median.
    reducers = (np.mean, np.std, np.median)
    # Source column -> output column label, in output order.
    labels = {'pitch': 'Tono', 'step': 'Paso', 'duration': 'Duración'}

    summary = {'Estadísticas': ['Media', 'Desviación Estandar', 'Mediana']}
    for column, label in labels.items():
        summary[label] = [reducer(notes_df[column]) for reducer in reducers]

    return pd.DataFrame(summary)
# UI instrument-family choice -> instrument name passed to notes_to_midi.
# The key order is also the order of the entries shown in the select box.
_INSTRUMENT_BY_OPTION = {
    "Piano": "Acoustic Grand Piano",
    "Percusión cromática": "Celesta",
    "Organo": "Hammond Organ",
    "Guitarra": "Acoustic Guitar (nylon)",
    "Bajo": "Acoustic Bass",
    "Instrumentos de cuerda": "Violin",
    "Conjunto": "String Ensemble 1",
    "Laton": "Trumpet",
    "Junco": "Soprano Sax",
    "Pipa": "Piccolo",
    "Instrumento sintetizador": "Lead 2 (sawtooth)",
    "Pad sintetizador": "Pad 2 (warm)",
    "Efecto sintetizador": "FX 2 (soundtrack)",
    "Etnico": "Banjo",
    "Percusion": "Melodic Tom",
    "Efectos de sonido": "Guitar Fret Noise",
}


def _load_model(option):
    """Load the saved model and weights matching the chosen training set."""
    if option == "Maestro":
        model_path, weights_path = "mi_modelo_music.h5", "mi_pesos_music.h5"
    else:
        model_path, weights_path = "mi_modelo03_music.h5", "mi_pesos03_music.h5"

    # The saved model references the custom loss by name, so it must be in
    # scope while the model is deserialised.
    with custom_object_scope({'mse_with_positive_pressure': mse_with_positive_pressure}):
        model = keras.models.load_model(model_path)
        model.load_weights(weights_path, skip_mismatch=False, by_name=False, options=None)
    return model


def _generate_notes(raw_notes, model, num_predictions,
                    seq_length=25, vocab_size=128, temperature=2.0):
    """Generate `num_predictions` notes seeded with the first notes of `raw_notes`."""
    key_order = ['pitch', 'step', 'duration']
    sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)
    # Only the pitch column of the seed window is scaled (by vocab_size).
    input_notes = sample_notes[:seq_length] / np.array([vocab_size, 1, 1])

    generated = []
    prev_start = 0
    for _ in range(num_predictions):
        pitch, step, duration = predict_next_note(input_notes, model, temperature)
        start = prev_start + step
        end = start + duration
        input_note = (pitch, step, duration)
        generated.append((*input_note, start, end))
        # Slide the window: drop the oldest note, append the new one.
        # NOTE(review): the appended pitch is NOT divided by vocab_size while
        # the seed window is - kept as in the original; confirm against how
        # the model was trained before changing it.
        input_notes = np.delete(input_notes, 0, axis=0)
        input_notes = np.append(input_notes, np.expand_dims(input_note, 0), axis=0)
        prev_start = start

    return pd.DataFrame(generated, columns=(*key_order, 'start', 'end'))


def _format_duration(seconds):
    """Format a duration in seconds as M:SS."""
    minutes, secs = divmod(seconds, 60)
    return f"{int(minutes)}:{int(secs):02d}"


def main():
    """Streamlit entry point: upload a MIDI file, generate a melody, offer it for download.

    Flow: seed the RNGs, render the controls, and - once a file is uploaded -
    save it to a private temporary directory, load the chosen model, generate
    notes from the upload's opening notes, write them to a MIDI file, and show
    that generated file's statistics, a download button and its duration.

    Changes from the original implementation:
      * The upload is stored in a temporary directory instead of the working
        directory under the client-supplied file name.
      * The statistics and the "resulting" duration are computed from the
        generated file; previously the uploaded bytes were re-saved as
        'temp.mid' and measured instead.
      * The generated file is read through a context manager and downloaded
        as 'output.mid'.
    """
    seed = 42
    tf.random.set_seed(seed)
    np.random.seed(seed)

    st.title('GENERADOR DE MELODIAS CON RNN')
    uploaded_file = st.file_uploader("Sube un archivo MIDI")

    with st.container(height=None, border=True):
        option = st.selectbox(
            "Elige con qué modelo entrenar",
            ("Maestro", "Lakh")
        )
        option_musica = st.selectbox(
            "Elige instrumento para generar las melodías",
            tuple(_INSTRUMENT_BY_OPTION)
        )
        num_predictions = st.number_input(
            "Ingrese el número de notas:", min_value=100, max_value=150, value=120, step=1)

    if not uploaded_file or option is None:
        return

    # Everything that touches the file system happens inside a per-run
    # temporary directory, removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as workdir:
        sample_file = os.path.join(workdir, 'input.mid')
        out_file = os.path.join(workdir, 'output.mid')

        st.subheader("Archivo cargado")
        with st.container(height=None, border=True):
            st.write(uploaded_file.name)
            with open(sample_file, 'wb') as f:
                f.write(uploaded_file.getbuffer())
            duracion = calcular_duracion_midi(sample_file)
            st.write(f"La duración del archivo MIDI subido es: {_format_duration(duracion)}")

        st.subheader("Modelo elegido")
        with st.container(height=None, border=True):
            st.write(option, " de tipo instrumental ", option_musica)

        model = _load_model(option)

        # Fall back to the upload's own first instrument when the choice is
        # not in the table.
        instrument_name = _INSTRUMENT_BY_OPTION.get(option_musica)
        if instrument_name is None:
            source_midi = pretty_midi.PrettyMIDI(sample_file)
            instrument_name = pretty_midi.program_to_instrument_name(
                source_midi.instruments[0].program)

        raw_notes = midi_to_notes(sample_file)
        if raw_notes.empty:
            st.error("El archivo MIDI no contiene notas.")
            return

        generated_notes = _generate_notes(raw_notes, model, int(num_predictions))
        notes_to_midi(
            generated_notes, out_file=out_file, instrument_name=instrument_name)

        st.title("Generador de notas musicales")
        # Read everything needed from the generated file before the temporary
        # directory is removed.
        with open(out_file, 'rb') as f:
            archivo_midi = f.read()
        statistics_df = obtain_statistics(out_file)
        duracion_f = calcular_duracion_midi(out_file)

    st.write("### Estadísticas generadas:")
    st.dataframe(statistics_df)

    with st.container(height=None, border=True):
        st.download_button(
            label="Descargar MIDI",
            data=archivo_midi,
            file_name='output.mid',
            mime='audio/midi'
        )
        st.write(f"La duración del archivo MIDI resultante es: {_format_duration(duracion_f)}")
# Run the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()