| """ |
| MuseMorphic MIDI Tokenizer |
| ========================== |
| |
| REMI+ tokenization with BPE compression for MIDI files. |
| Handles: bar boundaries, positions, pitches, velocities, durations, |
| tempo, time signatures, instruments, and control attributes. |
| |
| Based on: |
| - REMI (Huang & Yang, 2020) - Beat-aware positional encoding |
| - REMI+ (von Rütte et al., 2023) - Multi-track extension |
| - MIDI-RWKV (2025) - BPE compression for MIDI tokens |
| - MIDI-GPT (2025) - Attribute control tokens |
| """ |
|
|
| import json |
| import math |
| import os |
| from dataclasses import dataclass, field |
| from typing import List, Dict, Tuple, Optional, Union |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
|
|
@dataclass
class TokenizerConfig:
    """Configuration for the REMI+ tokenizer.

    All fields have defaults, so ``TokenizerConfig()`` is a complete
    configuration. ``__post_init__`` converts list-valued ranges (as produced
    by a JSON round trip) back to tuples and rejects values that would leave
    the vocabulary or the tick grid degenerate.

    Raises:
        ValueError: if a resolution/bin count/step is < 1, or a range has
            min > max.
    """

    # --- Timing grid -----------------------------------------------------
    ticks_per_beat: int = 480
    max_bar_length: int = 4  # NOTE(review): not read by the tokenizer code in this file
    position_resolution: int = 16  # number of Position_* slots per bar

    # --- Pitch: inclusive MIDI note range (21-108 is the 88-key piano) -----
    pitch_range: Tuple[int, int] = (21, 108)

    # --- Velocity / duration quantization ----------------------------------
    n_velocity_bins: int = 32
    n_duration_bins: int = 64  # durations are counted in Position_* steps

    # --- Tempo grid: tempo_range[0], tempo_range[0] + tempo_step, ... ------
    tempo_range: Tuple[int, int] = (30, 210)
    tempo_step: int = 4

    # --- Supported time signatures as (numerator, denominator) ------------
    time_signatures: List[Tuple[int, int]] = field(default_factory=lambda: [
        (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (3, 8), (6, 8), (12, 8)
    ])

    # --- Multi-track / BPE settings (stored and saved; not used by the
    # tokenization code visible in this file) ------------------------------
    max_tracks: int = 16
    bpe_vocab_size: int = 8192

    # --- Special tokens ----------------------------------------------------
    pad_token: str = "<PAD>"
    bos_token: str = "<BOS>"
    eos_token: str = "<EOS>"
    mask_token: str = "<MASK>"
    bar_token: str = "<BAR>"
    track_start_token: str = "<TRACK_START>"
    track_end_token: str = "<TRACK_END>"
    phrase_start_token: str = "<PHRASE_START>"
    phrase_end_token: str = "<PHRASE_END>"

    # --- Attribute-control token prefixes (tokens are "<PREFIX>_<level>") --
    ctrl_density_prefix: str = "DENSITY"
    ctrl_polyphony_prefix: str = "POLY"

    def __post_init__(self) -> None:
        # JSON round trips turn tuples into lists; restore the declared shapes
        # so a reloaded config compares equal to the one that was saved.
        self.pitch_range = tuple(self.pitch_range)
        self.tempo_range = tuple(self.tempo_range)
        self.time_signatures = [tuple(ts) for ts in self.time_signatures]

        # Values below 1 would give an empty vocabulary section or a zero-size
        # tick grid (division by zero during tokenization).
        for name in ("ticks_per_beat", "position_resolution", "n_velocity_bins",
                     "n_duration_bins", "tempo_step"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        for name in ("pitch_range", "tempo_range"):
            lo, hi = getattr(self, name)
            if lo > hi:
                raise ValueError(f"{name} must satisfy min <= max, got {(lo, hi)}")
|
|
|
|
class REMIPlusTokenizer:
    """
    REMI+ tokenizer for MIDI files.

    Converts note events -> REMI+ token strings -> integer IDs, and back.
    Supports phrase-level segmentation for PhraseVAE.

    Vocabulary layout for the default ``TokenizerConfig`` (IDs shift if the
    config changes any range, bin count, or the time-signature list):
        [0]       PAD
        [1]       BOS
        [2]       EOS
        [3]       MASK
        [4]       BAR
        [5]       TRACK_START
        [6]       TRACK_END
        [7]       PHRASE_START
        [8]       PHRASE_END
        [9-24]    Position_0 .. Position_15          (16)
        [25-32]   TimeSig tokens                     (8)
        [33-78]   Tempo_30 .. Tempo_210, step 4      (46)
        [79-166]  Pitch_21 .. Pitch_108              (88)
        [167-198] Velocity_1 .. Velocity_32          (32)
        [199-262] Duration_1 .. Duration_64          (64)
        [263-390] Program_0 .. Program_127           (128)
        [391-400] DENSITY_1 .. DENSITY_10            (10)
        [401-410] POLY_1 .. POLY_10                  (10)
    Base vocabulary size: 411. This class builds no BPE merges;
    ``bpe_vocab_size`` is only carried in the config.
    """

    def __init__(self, config: Optional["TokenizerConfig"] = None):
        self.config = config if config is not None else TokenizerConfig()
        self._build_vocabulary()

    # ------------------------------------------------------------------
    # Vocabulary helpers
    # ------------------------------------------------------------------

    def _build_vocabulary(self):
        """Build the base (pre-BPE) vocabulary, assigning IDs in layout order."""
        cfg = self.config
        self.token_to_id = {}
        self.id_to_token = {}

        def add(tok: str) -> None:
            # Every call consumes a fresh ID (id_to_token grows by one each time).
            idx = len(self.id_to_token)
            self.token_to_id[tok] = idx
            self.id_to_token[idx] = tok

        for tok in (cfg.pad_token, cfg.bos_token, cfg.eos_token, cfg.mask_token,
                    cfg.bar_token, cfg.track_start_token, cfg.track_end_token,
                    cfg.phrase_start_token, cfg.phrase_end_token):
            add(tok)
        for p in range(cfg.position_resolution):
            add(f"Position_{p}")
        for num, den in cfg.time_signatures:
            add(f"TimeSig_{num}/{den}")
        for bpm in range(cfg.tempo_range[0], cfg.tempo_range[1] + 1, cfg.tempo_step):
            add(f"Tempo_{bpm}")
        for p in range(cfg.pitch_range[0], cfg.pitch_range[1] + 1):
            add(f"Pitch_{p}")
        for v in range(1, cfg.n_velocity_bins + 1):
            add(f"Velocity_{v}")
        for d in range(1, cfg.n_duration_bins + 1):
            add(f"Duration_{d}")
        for prog in range(128):
            add(f"Program_{prog}")
        for level in range(1, 11):
            add(f"{cfg.ctrl_density_prefix}_{level}")
        for level in range(1, 11):
            add(f"{cfg.ctrl_polyphony_prefix}_{level}")

        self.base_vocab_size = len(self.id_to_token)
        self.vocab_size = self.base_vocab_size
        self._refresh_special_ids()

    def _refresh_special_ids(self) -> None:
        """Cache the IDs of the special tokens from the current ``token_to_id``."""
        cfg = self.config
        self.pad_id = self.token_to_id[cfg.pad_token]
        self.bos_id = self.token_to_id[cfg.bos_token]
        self.eos_id = self.token_to_id[cfg.eos_token]
        self.mask_id = self.token_to_id[cfg.mask_token]
        self.bar_id = self.token_to_id[cfg.bar_token]

    def _bar_geometry(self, time_sig: Tuple[int, int], ticks_per_beat: int) -> Tuple[int, int]:
        """Return ``(ticks_per_bar, ticks_per_position)`` for a time signature.

        A beat is a quarter note, so a bar spans ``num * 4 / den`` beats.
        ``ticks_per_position`` is kept >= 1 so a coarse tick resolution cannot
        cause a division by zero.
        """
        beats_per_bar = time_sig[0] * (4.0 / time_sig[1])
        ticks_per_bar = int(ticks_per_beat * beats_per_bar)
        ticks_per_position = max(1, ticks_per_bar // self.config.position_resolution)
        return ticks_per_bar, ticks_per_position

    def _quantize_tempo(self, tempo: float) -> int:
        """Snap ``tempo`` (BPM) to the nearest ``Tempo_*`` value in the vocabulary.

        Tempo tokens lie on the grid ``tempo_range[0] + k * tempo_step``, so the
        rounding must be measured from the lower bound. With the default config
        the grid is 30, 34, ..., 118, 122, ...; a plain multiple of 4 such as 120
        is *not* a vocabulary token.
        """
        lo, hi = self.config.tempo_range
        step = self.config.tempo_step
        max_k = (hi - lo) // step  # index of the largest Tempo_* token actually built
        k = round((tempo - lo) / step)
        return lo + max(0, min(max_k, k)) * step

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------

    def midi_to_remi_tokens(self, notes: List[Dict], tempo: float = 120.0,
                            time_sig: Tuple[int, int] = (4, 4)) -> List[str]:
        """
        Convert a list of note events to REMI+ token strings.

        Output shape: TRACK_START, [TimeSig], Tempo, then for every note
        (sorted by start, then pitch) any BAR tokens needed to reach its bar,
        followed by Position, Pitch, Velocity, Duration; finally TRACK_END.

        Args:
            notes: List of dicts with keys pitch, start, duration, velocity
                (times in ``config.ticks_per_beat`` ticks). A ``program`` key
                is currently ignored; all notes go into one track.
            tempo: BPM, snapped to the nearest tempo token.
            time_sig: (numerator, denominator). If it is not in the vocabulary
                no TimeSig token is emitted, but bar lengths still follow it.
                NOTE(review): the decoder then assumes 4/4.
        Returns:
            List of REMI+ token strings ([] for no notes).
        """
        if not notes:
            return []

        cfg = self.config
        tpb = cfg.ticks_per_beat
        ticks_per_bar, ticks_per_position = self._bar_geometry(time_sig, tpb)
        pitch_lo, pitch_hi = cfg.pitch_range
        last_position = cfg.position_resolution - 1

        notes = sorted(notes, key=lambda n: (n.get('start', 0), n.get('pitch', 0)))

        tokens = [cfg.track_start_token]
        ts_tok = f"TimeSig_{time_sig[0]}/{time_sig[1]}"
        if ts_tok in self.token_to_id:
            tokens.append(ts_tok)
        tokens.append(f"Tempo_{self._quantize_tempo(tempo)}")

        current_bar = -1
        for note in notes:
            start = note.get('start', 0)
            pitch = note.get('pitch', 60)
            duration = note.get('duration', tpb)
            velocity = note.get('velocity', 80)

            # One BAR per bar boundary crossed, so empty bars are kept.
            bar = int(start // ticks_per_bar)
            while current_bar < bar:
                current_bar += 1
                tokens.append(cfg.bar_token)

            position = min(int((start % ticks_per_bar) / ticks_per_position), last_position)
            tokens.append(f"Position_{position}")

            # Out-of-range pitches are clamped to the nearest representable pitch.
            tokens.append(f"Pitch_{max(pitch_lo, min(pitch_hi, pitch))}")

            # 0-127 velocity -> bins 1..n_velocity_bins of uniform width 128 / n.
            vel_bin = max(1, min(cfg.n_velocity_bins,
                                 int(velocity / 128 * cfg.n_velocity_bins) + 1))
            tokens.append(f"Velocity_{vel_bin}")

            # Duration in grid positions (floored), clamped to [1, n_duration_bins].
            dur_bin = max(1, min(cfg.n_duration_bins, int(duration / ticks_per_position)))
            tokens.append(f"Duration_{dur_bin}")

        tokens.append(cfg.track_end_token)
        return tokens

    def encode(self, tokens: List[str]) -> List[int]:
        """Convert token strings to IDs, wrapped as BOS ... EOS.

        Tokens that are not in the vocabulary are silently dropped.
        """
        lookup = self.token_to_id
        return [self.bos_id, *(lookup[tok] for tok in tokens if tok in lookup), self.eos_id]

    def decode(self, ids: List[int]) -> List[str]:
        """Convert IDs back to token strings.

        Unknown IDs are skipped and PAD/BOS/EOS are stripped; all other special
        tokens (MASK, BAR, TRACK_*, PHRASE_*) are kept.
        """
        cfg = self.config
        skip = {cfg.pad_token, cfg.bos_token, cfg.eos_token}
        tokens = []
        for id_ in ids:
            tok = self.id_to_token.get(id_)
            if tok is not None and tok not in skip:
                tokens.append(tok)
        return tokens

    # ------------------------------------------------------------------
    # Phrases and controls
    # ------------------------------------------------------------------

    def segment_into_phrases(self, tokens: List[str], bars_per_phrase: int = 1) -> List[List[str]]:
        """
        Segment a full REMI+ token sequence into phrase-level chunks.

        A new phrase starts at the BAR token that would push the current phrase
        past ``bars_per_phrase`` bars (one bar of one track by default,
        following the PhraseVAE convention). Tokens before the first BAR
        (TRACK_START, TimeSig, Tempo) stay in the first phrase; the trailing
        TRACK_END stays in the last one.
        """
        bar_tok = self.config.bar_token
        phrases = []
        current_phrase = []
        bar_count = 0

        for tok in tokens:
            if tok == bar_tok:
                bar_count += 1
                if bar_count > bars_per_phrase and current_phrase:
                    phrases.append(current_phrase)
                    current_phrase = []
                    bar_count = 1
            current_phrase.append(tok)

        if current_phrase:
            phrases.append(current_phrase)

        return phrases

    def compute_controls(self, phrase_tokens: List[str]) -> Dict[str, int]:
        """
        Compute control attributes from a phrase's tokens.

        Controls:
            - density: number of Pitch tokens, binned as ``count // 3 + 1`` and
              clipped to 1-10.
            - polyphony: the largest number of notes starting at the same
              (bar, position) onset, clipped to 1-10. Onsets are keyed by bar
              as well as position so the same Position index in different bars
              of a multi-bar phrase is not merged.
        """
        note_count = sum(1 for tok in phrase_tokens if tok.startswith("Pitch_"))
        density = min(10, max(1, int(note_count / 3) + 1))

        bar_tok = self.config.bar_token
        onsets: Dict[Tuple[int, int], int] = {}
        bar_idx = 0
        position = 0
        for tok in phrase_tokens:
            if tok == bar_tok:
                bar_idx += 1
                position = 0
            elif tok.startswith("Position_"):
                position = int(tok.split("_")[1])
            elif tok.startswith("Pitch_"):
                key = (bar_idx, position)
                onsets[key] = onsets.get(key, 0) + 1

        max_poly = max(onsets.values()) if onsets else 1
        polyphony = min(10, max(1, max_poly))

        return {'density': density, 'polyphony': polyphony}

    def pad_sequence(self, ids: List[int], max_len: int) -> List[int]:
        """Truncate ``ids`` to ``max_len`` or right-pad it with PAD to that length."""
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_id] * (max_len - len(ids))

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str):
        """Save the vocabulary and the full config to ``<path>/tokenizer.json``."""
        from dataclasses import asdict

        os.makedirs(path, exist_ok=True)
        data = {
            'token_to_id': self.token_to_id,
            # Persist every config field (time signatures, special tokens, ...)
            # so load() can rebuild an identical configuration.
            'config': asdict(self.config),
        }
        with open(os.path.join(path, 'tokenizer.json'), 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load(cls, path: str) -> 'REMIPlusTokenizer':
        """Load a tokenizer written by :meth:`save`.

        Files that stored only a subset of the config still load (missing
        fields take their defaults). Keys that are not ``TokenizerConfig``
        fields are ignored. The saved ``token_to_id`` mapping is authoritative:
        ``id_to_token``, ``vocab_size`` and the cached special-token IDs are
        rebuilt from it.
        """
        from dataclasses import fields

        with open(os.path.join(path, 'tokenizer.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)

        known = {fld.name for fld in fields(TokenizerConfig)}
        cfg_kwargs = {k: v for k, v in data.get('config', {}).items() if k in known}
        # JSON turns tuples into lists; restore the tuple shapes the config declares.
        for key in ('pitch_range', 'tempo_range'):
            if key in cfg_kwargs:
                cfg_kwargs[key] = tuple(cfg_kwargs[key])
        if 'time_signatures' in cfg_kwargs:
            cfg_kwargs['time_signatures'] = [tuple(ts) for ts in cfg_kwargs['time_signatures']]

        tokenizer = cls(TokenizerConfig(**cfg_kwargs))
        tokenizer.token_to_id = {tok: int(idx) for tok, idx in data['token_to_id'].items()}
        tokenizer.id_to_token = {idx: tok for tok, idx in tokenizer.token_to_id.items()}
        if tokenizer.id_to_token:
            tokenizer.vocab_size = max(tokenizer.id_to_token) + 1
        tokenizer._refresh_special_ids()
        return tokenizer

    # ------------------------------------------------------------------
    # Decoding to notes
    # ------------------------------------------------------------------

    def tokens_to_midi_notes(self, tokens: List[str], ticks_per_beat: int = 480) -> List[Dict]:
        """
        Convert REMI+ tokens back to note events.

        Assumes 4/4 until a ``TimeSig_*`` token is seen. A note is emitted on
        each ``Duration_*`` token that follows a ``Pitch_*``. A missing velocity
        falls back to the middle velocity bin. Tempo and Program tokens do not
        affect the output.

        Args:
            tokens: REMI+ token strings.
            ticks_per_beat: tick resolution of the returned times. It should
                match the resolution used when encoding (``config.ticks_per_beat``).
        Returns:
            List of dicts: {pitch, start, duration, velocity}, times in ticks.
        """
        cfg = self.config
        n_vel_bins = cfg.n_velocity_bins
        default_velocity_bin = max(1, n_vel_bins // 2)
        ticks_per_bar, ticks_per_position = self._bar_geometry((4, 4), ticks_per_beat)

        notes = []
        current_bar = -1  # the first BAR token moves this to bar 0
        current_position = 0
        pending_pitch = None
        pending_velocity = None

        for tok in tokens:
            if tok.startswith("TimeSig_"):
                parts = tok.split("_")[1].split("/")
                ticks_per_bar, ticks_per_position = self._bar_geometry(
                    (int(parts[0]), int(parts[1])), ticks_per_beat)

            elif tok == cfg.bar_token:
                current_bar += 1

            elif tok.startswith("Position_"):
                current_position = int(tok.split("_")[1])

            elif tok.startswith("Pitch_"):
                pending_pitch = int(tok.split("_")[1])

            elif tok.startswith("Velocity_"):
                pending_velocity = int(tok.split("_")[1])

            elif tok.startswith("Duration_") and pending_pitch is not None:
                dur_bin = int(tok.split("_")[1])
                start = current_bar * ticks_per_bar + current_position * ticks_per_position
                vel_bin = pending_velocity if pending_velocity is not None else default_velocity_bin
                velocity = int(vel_bin / n_vel_bins * 127)

                notes.append({
                    'pitch': pending_pitch,
                    'start': max(0, start),
                    'duration': dur_bin * ticks_per_position,
                    'velocity': min(127, max(1, velocity)),
                })
                pending_pitch = None
                pending_velocity = None

        return notes
|
|
|
|
| |
| |
| |
|
|
def notes_to_midi_file(notes: List[Dict], output_path: str,
                       tempo: float = 120.0, ticks_per_beat: int = 480) -> bool:
    """
    Write note events to a single-track MIDI file.

    Uses midiutil for lightweight MIDI writing (no heavy dependencies). All
    notes go to track 0, channel 0; any ``program`` key is ignored.

    Args:
        notes: dicts with ``pitch``, ``start``, ``duration`` (both in ticks)
            and ``velocity``.
        output_path: destination file path.
        tempo: BPM written at beat 0.
        ticks_per_beat: resolution that ``start``/``duration`` are expressed in.
    Returns:
        True if the file was written, False if midiutil is not installed.
    """
    # Only the import is guarded: an ImportError raised while building or
    # writing the file must not be misreported as a missing dependency.
    try:
        from midiutil import MIDIFile
    except ImportError:
        print("midiutil not installed. Install with: pip install midiutil")
        return False

    midi = MIDIFile(1, ticks_per_quarternote=ticks_per_beat)
    midi.addTempo(0, 0, tempo)

    for note in notes:
        # midiutil takes times in beats (quarter notes), so convert from ticks.
        start_beat = note['start'] / ticks_per_beat
        duration_beat = note['duration'] / ticks_per_beat
        midi.addNote(0, 0, note['pitch'], start_beat, duration_beat, note['velocity'])

    with open(output_path, 'wb') as f:
        midi.writeFile(f)

    return True
|
|
|
|
def midi_file_to_notes(midi_path: str) -> Tuple[List[Dict], float, Tuple[int, int]]:
    """
    Read a MIDI file and extract its non-drum note events.

    Note times come from the file's own tick grid, rescaled to 480 ticks per
    beat. They are not re-derived from an estimated tempo, so bar and beat
    positions line up with the score. This holds even when the file's tempo
    changes mid-piece, and it also works for files with fewer than two notes.

    Returns:
        (notes, tempo, time_signature):
            - notes: dicts with pitch, start, duration (ticks at 480 TPB),
              velocity and program.
            - tempo: the file's first tempo marking, in BPM.
            - time_signature: the first time signature, or (4, 4) if there is
              none.
        If pretty_midi is not installed, returns ([], 120.0, (4, 4)).
    """
    try:
        import pretty_midi
    except ImportError:
        print("pretty_midi not installed. Install with: pip install pretty_midi")
        return [], 120.0, (4, 4)

    pm = pretty_midi.PrettyMIDI(midi_path)

    # Use the tempo written in the file rather than an onset-based estimate.
    _, tempi = pm.get_tempo_changes()
    tempo = float(tempi[0]) if len(tempi) else 120.0

    if pm.time_signature_changes:
        ts = pm.time_signature_changes[0]
        time_sig = (ts.numerator, ts.denominator)
    else:
        time_sig = (4, 4)

    tpb = 480
    scale = tpb / pm.resolution  # file ticks -> 480-TPB ticks

    def to_ticks(seconds: float) -> int:
        # time_to_tick follows the file's tempo map, so beats stay on the grid.
        return int(round(pm.time_to_tick(seconds) * scale))

    notes = []
    for instrument in pm.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            start_ticks = to_ticks(note.start)
            end_ticks = to_ticks(note.end)
            notes.append({
                'pitch': note.pitch,
                'start': start_ticks,
                'duration': max(1, end_ticks - start_ticks),
                'velocity': note.velocity,
                'program': instrument.program,
            })

    return notes, tempo, time_sig
|
|