| import os
|
| import sys
|
| import numpy as np
|
| import pretty_midi
|
| import mir_eval
|
| import matplotlib.pyplot as plt
|
| from basic_pitch.inference import predict
|
|
|
|
|
# Location of the exported CornetAI model, handed to basic_pitch's `predict`.
MODEL_PATH: str = "CornetAI_SavedModel"

# Onset tolerance (seconds) used when matching performed notes to the score.
ONSET_TOL: float = 0.150

# F1 value that maps to a 10/10 grade in `get_friendly_score`.
# NOTE(review): presumably the best F1 the model reaches on its own test set — confirm.
CEILING_F1: float = 0.858
|
|
|
def get_friendly_score(f1_value):
    """Turn a raw F1 value into a user-facing grade out of 10.

    The F1 is rescaled so that ``CEILING_F1`` corresponds to a 10, rounded to
    one decimal, and capped so the grade never exceeds 10.0.

    Args:
        f1_value: F-measure as reported by mir_eval (presumably in [0, 1]).

    Returns:
        The rounded grade; 10.0 whenever the rescaled value is not below 10.
    """
    grade = round((f1_value / CEILING_F1) * 10, 1)
    # Mirrors min(10.0, grade): the grade is kept only if strictly below 10.
    return grade if grade < 10.0 else 10.0
|
|
|
# Simplified Spanish name for each pitch class (index = MIDI number mod 12).
# Accidentals are folded onto a natural name: C#/Db and D#/Eb -> "Re",
# F#/Gb -> "Fa", G#/Ab -> "La", A#/Bb -> "Si".
_NOMBRES_CORNETA = (
    "Do", "Re", "Re", "Re", "Mi", "Fa",
    "Fa", "Sol", "La", "La", "Si", "Si",
)


def midi_a_espanol_corneta(midi_num):
    """Return the simplified Spanish note name used for the cornet.

    Only the pitch class matters; the octave is discarded.

    Args:
        midi_num: Integer MIDI note number.

    Returns:
        One of "Do", "Re", "Mi", "Fa", "Sol", "La", "Si".
    """
    pitch_class = midi_num % 12
    return _NOMBRES_CORNETA[pitch_class]
|
|
|
def _dibujar_notas(ax, notes, etiqueta, **estilo):
    """Draw each note as a horizontal bar at its pitch, spanning start..end.

    Only the first bar carries ``etiqueta``; the rest get an empty label so the
    legend shows a single entry per series.
    """
    for i, note in enumerate(notes):
        ax.barh(note.pitch, note.end - note.start, left=note.start,
                height=0.4, label=etiqueta if i == 0 else "", **estilo)


def plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, output_filename):
    """Draw reference vs. performance piano rolls and save them as a PNG.

    Notes of the first instrument of ``pm_ref`` are drawn as outlined green
    bars, notes of the first instrument of ``pm_est`` as blue bars. The score
    and the correction list are rendered in a text panel to the right.

    Args:
        pm_ref: PrettyMIDI object with the reference score.
        pm_est: PrettyMIDI object produced by the transcription.
        puntuacion: Grade shown as ``"<puntuacion>/10"``.
        fallos_texto: Multi-line string with the corrections to display.
        output_filename: Name of the analysed audio file. It is used in the
            title, and the image is written to ``<name without extension>_resultado.png``.
    """
    ref_notes = pm_ref.instruments[0].notes
    est_notes = pm_est.instruments[0].notes

    fig = plt.figure(figsize=(15, 8))
    # Leave the right part of the figure free for the results panel.
    ax = fig.add_axes([0.1, 0.1, 0.6, 0.8])

    _dibujar_notas(ax, ref_notes, "Partitura (Referencia)",
                   color='green', alpha=0.3, edgecolor='black', linewidth=1.5)
    _dibujar_notas(ax, est_notes, "Tu Ejecución (CornetAI)",
                   color='blue', alpha=0.5)

    # One y tick per pitch present in either roll, labelled with its Spanish name.
    all_pitches = sorted({n.pitch for n in ref_notes} | {n.pitch for n in est_notes})
    ax.set_yticks(all_pitches)
    ax.set_yticklabels([midi_a_espanol_corneta(p) for p in all_pitches])

    ax.set_xlabel("Tiempo (segundos)")
    ax.set_ylabel("Nota")
    ax.set_title(f"Informe de Evaluación CornetAI - {output_filename}")
    if ref_notes or est_notes:
        # With no labelled bars, legend() would only emit a warning.
        ax.legend(loc='upper left')
    ax.grid(axis='x', linestyle='--', alpha=0.3)

    info_text = f"PUNTUACIÓN: {puntuacion}/10\n\nCORRECCIONES:\n{fallos_texto}"
    fig.text(0.72, 0.85, " RESULTADOS", fontsize=16, fontweight='bold', color='darkblue')
    fig.text(0.72, 0.5, info_text, fontsize=12, va='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Build the name from the stem so any audio extension (.mp3, .WAV, .flac...)
    # works: str.replace(".wav", ...) left such names unchanged, and savefig then
    # failed trying to infer an image format from the audio extension.
    img_name = os.path.splitext(output_filename)[0] + "_resultado.png"
    fig.savefig(img_name, dpi=300, bbox_inches='tight')
    print(f"✅ Imagen de resultados guardada como: {img_name}")

    plt.show()
    # Release the figure so it does not stay registered in pyplot (e.g. when a
    # non-interactive backend makes show() return immediately).
    plt.close(fig)
|
|
|
# How many missed notes are listed individually before summarising the rest.
_MAX_FALLOS_MOSTRADOS = 8


def _notas_a_arrays(notes):
    """Convert pretty_midi notes into (intervals, pitches_hz) arrays for mir_eval.

    Intervals are always shaped ``(n, 2)``, even for an empty list, because
    mir_eval rejects a 1-D empty interval array.
    """
    intervals = np.array([[n.start, n.end] for n in notes], dtype=float).reshape(-1, 2)
    pitches = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in notes], dtype=float)
    return intervals, pitches


def _indices_no_acertados(ref_int, ref_pit, est_int, est_pit):
    """Return the sorted indices of reference notes with no matching performed note.

    Matching uses the same criteria as the grade: onset within ``ONSET_TOL``,
    mir_eval's default pitch tolerance, and offsets ignored. The correction list
    therefore agrees with the score, so a wrong pitch played on time is flagged.
    """
    pares = mir_eval.transcription.match_notes(
        ref_int, ref_pit, est_int, est_pit,
        onset_tolerance=ONSET_TOL, offset_ratio=None,
    )
    acertados = {i_ref for i_ref, _ in pares}
    return [i for i in range(len(ref_int)) if i not in acertados]


def evaluar_ejecucion(audio_path, midi_gt_path):
    """Grade a recorded performance against a reference MIDI and plot the result.

    The audio is transcribed with the CornetAI model (basic_pitch ``predict``) and
    compared with the first instrument of the reference MIDI using mir_eval.
    The comparison uses onset tolerance ``ONSET_TOL`` and ignores offsets. The
    function then prints the grade and up to ``_MAX_FALLOS_MOSTRADOS`` missed
    reference notes, and calls ``plot_piano_rolls``.

    Problems such as missing files, an unreadable or empty reference, failed
    inference or no detected notes are reported on stdout. In those cases the
    function returns ``None`` early.
    """
    if not os.path.exists(audio_path) or not os.path.exists(midi_gt_path):
        print("❌ Error: Archivos no encontrados.")
        return

    # Load the reference before the (slow) inference so a bad MIDI fails fast.
    try:
        pm_ref = pretty_midi.PrettyMIDI(midi_gt_path)
    except Exception as e:  # CLI boundary: report and stop.
        print(f"Error leyendo la partitura MIDI: {e}")
        return
    if not pm_ref.instruments or not pm_ref.instruments[0].notes:
        print("La partitura de referencia no contiene notas.")
        return
    ref_notes = pm_ref.instruments[0].notes

    print("Analizando interpretación...")
    try:
        _, pm_est, _ = predict(audio_path, model_or_model_path=MODEL_PATH)
    except Exception as e:  # CLI boundary: report and stop.
        print(f"Error en la inferencia: {e}")
        return

    est_notes = pm_est.instruments[0].notes if pm_est.instruments else []
    if not est_notes:
        print("No se han detectado notas.")
        return

    ref_int, ref_pit = _notas_a_arrays(ref_notes)
    est_int, est_pit = _notas_a_arrays(est_notes)

    metrics = mir_eval.transcription.evaluate(
        ref_int, ref_pit, est_int, est_pit,
        onset_tolerance=ONSET_TOL, offset_ratio=None,
    )
    puntuacion = get_friendly_score(metrics['F-measure_no_offset'])

    fallados_idx = _indices_no_acertados(ref_int, ref_pit, est_int, est_pit)

    print("\n" + "=" * 45)
    print("EVALUACIÓN DE CORNETAI")
    print("=" * 45)
    print(f">> NOTA FINAL: {puntuacion} / 10 <<")
    print("=" * 45)

    fallos_lista = []
    for idx in fallados_idx[:_MAX_FALLOS_MOSTRADOS]:
        nota = ref_notes[idx]
        nota_es = midi_a_espanol_corneta(nota.pitch)
        tiempo = round(nota.start, 2)
        fallos_lista.append(f"- Revisa {nota_es} ({tiempo}s)")
        print(f" - Revisa el {nota_es} en el segundo {tiempo}")

    # Count only the misses that were not listed above (previously the slice
    # showed 8 but the remainder was computed against 6).
    restantes = len(fallados_idx) - _MAX_FALLOS_MOSTRADOS
    if restantes > 0:
        fallos_lista.append(f"... y {restantes} más.")
    if not fallados_idx:
        fallos_lista.append("¡Interpretación Perfecta!")

    fallos_texto = "\n".join(fallos_lista)

    plot_piano_rolls(pm_ref, pm_est, puntuacion, fallos_texto, os.path.basename(audio_path))
|
|
|
if __name__ == "__main__":
    if len(sys.argv) != 3:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # so callers/scripts can detect incorrect usage.
        sys.exit("Uso: python evaluador_individual.py <audio.wav> <referencia.mid>")
    evaluar_ejecucion(sys.argv[1], sys.argv[2])