ROOM / scripts /validate_midi.py
solo363614's picture
Upload folder using huggingface_hub
aed1d05 verified
#!/usr/bin/env python3
"""
Validate generated MIDI files for quality.
Checks:
- File integrity (can be parsed)
- Note count and distribution
- Pitch diversity
- Temporal structure
- Velocity patterns
"""
import argparse
import json
from pathlib import Path
from collections import Counter
import numpy as np
import pretty_midi
from tqdm import tqdm
def analyze_midi(midi_path: str) -> dict:
    """Compute summary statistics for one MIDI file.

    Args:
        midi_path: Path of the MIDI file to load with ``pretty_midi``.

    Returns:
        A dict. When the file cannot be loaded, or holds no notes, it is
        ``{"valid": False, "error": <message>}``. Otherwise ``"valid"`` is
        ``True`` and the remaining keys carry rounded statistics over the
        notes of all instruments combined: note count, pitch / velocity /
        duration mean and standard deviation, total duration, note density
        (notes divided by total duration), the share of notes sitting on
        the single most frequent pitch, and the number of instruments.
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
    except Exception as exc:
        # Any loading failure is reported as an invalid file, not raised.
        return {"valid": False, "error": str(exc)}

    # Pool the notes of every instrument into one flat list.
    notes = [note for instrument in midi.instruments for note in instrument.notes]
    if not notes:
        return {"valid": False, "error": "No notes found"}

    def _mean_std(values, digits):
        """Return (mean, std) of *values*, each rounded to *digits* places."""
        return round(np.mean(values), digits), round(np.std(values), digits)

    pitch_values = [note.pitch for note in notes]
    pitch_mean, pitch_std = _mean_std(pitch_values, 2)
    velocity_mean, velocity_std = _mean_std([note.velocity for note in notes], 2)
    duration_mean, duration_std = _mean_std(
        [note.end - note.start for note in notes], 4
    )

    # Density is notes per unit of total duration; guard a zero-length file.
    end_time = midi.get_end_time()
    if end_time > 0:
        density = len(notes) / end_time
    else:
        density = 0

    # Fraction of all notes that fall on the most frequently used pitch.
    top_pitch_count = max(Counter(pitch_values).values())
    top_pitch_share = top_pitch_count / len(pitch_values)

    return {
        "valid": True,
        "note_count": len(notes),
        "unique_pitches": len(set(pitch_values)),
        "pitch_range": max(pitch_values) - min(pitch_values),
        "pitch_mean": pitch_mean,
        "pitch_std": pitch_std,
        "velocity_mean": velocity_mean,
        "velocity_std": velocity_std,
        "duration_mean": duration_mean,
        "duration_std": duration_std,
        "total_duration": round(end_time, 2),
        "note_density": round(density, 2),
        "most_common_pitch_ratio": round(top_pitch_share, 4),
        "num_instruments": len(midi.instruments),
    }
def _percentage(part: int, total: int) -> float:
    """Return *part* as a percentage of *total*, or 0.0 when *total* is 0."""
    return part / total * 100 if total else 0.0


def _failed_quality_checks(
    analysis: dict,
    min_notes: int,
    max_notes: int,
    min_unique_pitches: int,
    min_duration: float,
    max_repetition_ratio: float,
) -> list:
    """Return the names of the quality checks that *analysis* fails."""
    checks = (
        ("too_few_notes", analysis["note_count"] < min_notes),
        ("too_many_notes", analysis["note_count"] > max_notes),
        ("low_pitch_diversity", analysis["unique_pitches"] < min_unique_pitches),
        ("too_short", analysis["total_duration"] < min_duration),
        ("too_repetitive", analysis["most_common_pitch_ratio"] > max_repetition_ratio),
    )
    return [name for name, is_failing in checks if is_failing]


def validate_batch(
    midi_dir: str,
    output_path: str = None,
    min_notes: int = 20,
    max_notes: int = 2000,
    min_unique_pitches: int = 5,
    min_duration: float = 5.0,
    max_repetition_ratio: float = 0.5,
) -> dict:
    """Validate every ``*.mid`` / ``*.midi`` file found under a directory.

    Each file is analyzed with ``analyze_midi`` and then checked against the
    quality thresholds. A summary is printed and, if requested, the full
    report is written as JSON.

    Args:
        midi_dir: Directory searched recursively for MIDI files.
        output_path: Where to write the JSON report; nothing is written when
            this is ``None`` or empty.
        min_notes: Files with fewer notes fail with ``too_few_notes``.
        max_notes: Files with more notes fail with ``too_many_notes``.
        min_unique_pitches: Files with fewer distinct pitches fail with
            ``low_pitch_diversity``.
        min_duration: Files whose total duration is shorter fail with
            ``too_short``.
        max_repetition_ratio: Files whose most common pitch accounts for a
            larger share of notes fail with ``too_repetitive``.

    Returns:
        The report dict: overall counters, a per-file ``"files"`` list, a
        ``"summary"`` of averages (only when at least one file parsed), and
        ``"quality_failure_counts"``.
    """
    midi_dir = Path(midi_dir)
    # Sorted so the report is reproducible regardless of filesystem order.
    midi_files = sorted(
        list(midi_dir.rglob("*.mid")) + list(midi_dir.rglob("*.midi"))
    )
    total = len(midi_files)
    print(f"Found {total} MIDI files")

    results = {
        "total": total,
        "valid": 0,
        "invalid": 0,
        "passed_quality": 0,
        "failed_quality": 0,
        "files": [],
    }
    quality_failures = Counter()

    # There is nothing to show progress for in an empty batch, so skip the bar.
    progress = tqdm(midi_files, desc="Validating") if midi_files else midi_files
    for midi_path in progress:
        analysis = analyze_midi(str(midi_path))
        analysis["path"] = str(midi_path)

        if not analysis.get("valid"):
            # Unparseable (or note-less) files never reach the quality checks.
            results["invalid"] += 1
            analysis["quality_passed"] = False
            quality_failures["parse_error"] += 1
        else:
            results["valid"] += 1
            failed = _failed_quality_checks(
                analysis,
                min_notes,
                max_notes,
                min_unique_pitches,
                min_duration,
                max_repetition_ratio,
            )
            if failed:
                results["failed_quality"] += 1
                analysis["quality_passed"] = False
                analysis["quality_failures"] = failed
                quality_failures.update(failed)
            else:
                results["passed_quality"] += 1
                analysis["quality_passed"] = True

        results["files"].append(analysis)

    # Averages are only meaningful over files that parsed successfully.
    valid_analyses = [f for f in results["files"] if f.get("valid")]
    if valid_analyses:
        results["summary"] = {
            "avg_notes": round(np.mean([f["note_count"] for f in valid_analyses]), 1),
            "avg_unique_pitches": round(np.mean([f["unique_pitches"] for f in valid_analyses]), 1),
            "avg_duration": round(np.mean([f["total_duration"] for f in valid_analyses]), 1),
            "avg_note_density": round(np.mean([f["note_density"] for f in valid_analyses]), 2),
        }
    results["quality_failure_counts"] = dict(quality_failures)

    # Percentages go through _percentage so an empty directory reports 0.0%
    # instead of raising ZeroDivisionError.
    valid_pct = _percentage(results["valid"], total)
    passed_pct = _percentage(results["passed_quality"], total)

    print("\n" + "=" * 60)
    print("Validation Summary")
    print("=" * 60)
    print(f"Total files: {results['total']}")
    print(f"Valid (parseable): {results['valid']} ({valid_pct:.1f}%)")
    print(f"Invalid: {results['invalid']}")
    print(f"Passed quality: {results['passed_quality']} ({passed_pct:.1f}%)")
    print(f"Failed quality: {results['failed_quality']}")

    if valid_analyses:
        summary = results["summary"]
        print("\nValid file statistics:")
        print(f" Avg notes: {summary['avg_notes']}")
        print(f" Avg unique pitches: {summary['avg_unique_pitches']}")
        print(f" Avg duration: {summary['avg_duration']}s")
        print(f" Avg note density: {summary['avg_note_density']} notes/s")

    if quality_failures:
        print("\nQuality failure breakdown:")
        for reason, count in quality_failures.most_common():
            print(f" {reason}: {count}")

    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {output_path}")

    return results
def main():
    """Command-line entry point: parse arguments and run ``validate_batch``.

    Every quality threshold accepted by ``validate_batch`` has a matching
    option here, with the same default, so none of the checks is fixed at
    its default when the script is run from the shell.
    """
    parser = argparse.ArgumentParser(description="Validate MIDI files")
    parser.add_argument(
        "midi_dir",
        type=str,
        help="Directory containing MIDI files",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="validation_results.json",
        help="Output JSON file path",
    )
    parser.add_argument(
        "--min-notes",
        type=int,
        default=20,
        help="Minimum note count",
    )
    parser.add_argument(
        "--max-notes",
        type=int,
        default=2000,
        help="Maximum note count",
    )
    parser.add_argument(
        "--min-unique-pitches",
        type=int,
        default=5,
        help="Minimum number of distinct pitches",
    )
    parser.add_argument(
        "--min-duration",
        type=float,
        default=5.0,
        help="Minimum total duration in seconds",
    )
    parser.add_argument(
        "--max-repetition-ratio",
        type=float,
        default=0.5,
        help="Maximum share of notes allowed on the most common pitch",
    )
    args = parser.parse_args()

    validate_batch(
        args.midi_dir,
        args.output,
        min_notes=args.min_notes,
        max_notes=args.max_notes,
        min_unique_pitches=args.min_unique_pitches,
        min_duration=args.min_duration,
        max_repetition_ratio=args.max_repetition_ratio,
    )
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()