asdf98 committed on
Commit
09ff16c
·
verified ·
1 Parent(s): f88bd6f

Upload musemorphic/tokenizer.py

Browse files
Files changed (1) hide show
  1. musemorphic/tokenizer.py +521 -0
musemorphic/tokenizer.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MuseMorphic MIDI Tokenizer
3
+ ==========================
4
+
5
+ REMI+ tokenization with BPE compression for MIDI files.
6
+ Handles: bar boundaries, positions, pitches, velocities, durations,
7
+ tempo, time signatures, instruments, and control attributes.
8
+
9
+ Based on:
10
+ - REMI (Huang & Yang, 2020) - Beat-aware positional encoding
11
+ - REMI+ (von Rütte et al., 2023) - Multi-track extension
12
+ - MIDI-RWKV (2025) - BPE compression for MIDI tokens
13
+ - MIDI-GPT (2025) - Attribute control tokens
14
+ """
15
+
16
+ import json
17
+ import math
18
+ import os
19
+ from dataclasses import dataclass, field
20
+ from typing import List, Dict, Tuple, Optional, Union
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+
25
+
26
@dataclass
class TokenizerConfig:
    """Configuration for the REMI+ tokenizer.

    All fields have defaults, so ``TokenizerConfig()`` yields a usable
    configuration. Sequence-valued fields are normalised in
    ``__post_init__`` so that a configuration rebuilt from JSON (where tuples
    come back as lists) compares equal to the one that was saved.
    """

    # Resolution
    ticks_per_beat: int = 480
    max_bar_length: int = 4  # Max bars per phrase
    position_resolution: int = 16  # Grid slots per bar (16th notes in 4/4)

    # Pitch
    pitch_range: Tuple[int, int] = (21, 108)  # Piano range (A0-C8), inclusive

    # Velocity
    n_velocity_bins: int = 32

    # Duration
    n_duration_bins: int = 64  # Bins are counted in position-grid slots

    # Tempo
    tempo_range: Tuple[int, int] = (30, 210)  # BPM, inclusive
    tempo_step: int = 4  # BPM distance between neighbouring tempo tokens

    # Time Signature
    time_signatures: List[Tuple[int, int]] = field(default_factory=lambda: [
        (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (3, 8), (6, 8), (12, 8)
    ])

    # Instruments (General MIDI programs to track)
    max_tracks: int = 16

    # BPE
    bpe_vocab_size: int = 8192

    # Special tokens
    pad_token: str = "<PAD>"
    bos_token: str = "<BOS>"
    eos_token: str = "<EOS>"
    mask_token: str = "<MASK>"
    bar_token: str = "<BAR>"
    track_start_token: str = "<TRACK_START>"
    track_end_token: str = "<TRACK_END>"
    phrase_start_token: str = "<PHRASE_START>"
    phrase_end_token: str = "<PHRASE_END>"

    # Control tokens
    ctrl_density_prefix: str = "DENSITY"
    ctrl_polyphony_prefix: str = "POLY"

    def __post_init__(self):
        """Coerce range and time-signature fields to their declared types."""
        # JSON has no tuple type, so values restored from disk arrive as
        # lists; convert them back to keep equality checks and hashing sane.
        self.pitch_range = tuple(self.pitch_range)
        self.tempo_range = tuple(self.tempo_range)
        self.time_signatures = [tuple(ts) for ts in self.time_signatures]
73
+
74
+
75
class REMIPlusTokenizer:
    """
    REMI+ tokenizer for MIDI note events.

    Converts note events -> REMI+ token strings -> integer IDs, and back.
    Supports phrase-level segmentation for PhraseVAE.

    Vocabulary structure (IDs shown for the default TokenizerConfig; the
    actual ranges follow from the configured sizes):
        [0]       PAD
        [1]       BOS
        [2]       EOS
        [3]       MASK
        [4]       BAR
        [5]       TRACK_START
        [6]       TRACK_END
        [7]       PHRASE_START
        [8]       PHRASE_END
        [9-24]    Position_0 to Position_15
        [25-32]   TimeSig tokens (8 signatures)
        [33-78]   Tempo tokens (30-210, step 4)
        [79-166]  Pitch tokens (21-108)
        [167-198] Velocity tokens (1-32)
        [199-262] Duration tokens (1-64)
        [263-390] Program tokens (0-127)
        [391-400] Density control tokens (1-10)
        [401-410] Polyphony control tokens (1-10)
        [411+]    BPE merge tokens
    """

    def __init__(self, config: Optional["TokenizerConfig"] = None):
        """Create a tokenizer; a default TokenizerConfig is used if none is given."""
        self.config = config or TokenizerConfig()
        self._build_vocabulary()

    def _register_token(self, tok: str) -> None:
        """Assign the next free integer ID to ``tok`` in both lookup tables."""
        idx = len(self.id_to_token)
        self.token_to_id[tok] = idx
        self.id_to_token[idx] = tok

    def _build_vocabulary(self):
        """Build the base vocabulary before BPE.

        Tokens are registered in a fixed order (special, position, time
        signature, tempo, pitch, velocity, duration, program, density,
        polyphony) so that IDs are reproducible for a given config.
        """
        cfg = self.config
        self.token_to_id = {}
        self.id_to_token = {}

        names = [
            cfg.pad_token, cfg.bos_token, cfg.eos_token, cfg.mask_token,
            cfg.bar_token, cfg.track_start_token, cfg.track_end_token,
            cfg.phrase_start_token, cfg.phrase_end_token,
        ]
        names.extend(f"Position_{p}" for p in range(cfg.position_resolution))
        names.extend(f"TimeSig_{num}/{den}" for num, den in cfg.time_signatures)
        names.extend(
            f"Tempo_{bpm}"
            for bpm in range(cfg.tempo_range[0], cfg.tempo_range[1] + 1, cfg.tempo_step)
        )
        names.extend(
            f"Pitch_{p}" for p in range(cfg.pitch_range[0], cfg.pitch_range[1] + 1)
        )
        names.extend(f"Velocity_{v}" for v in range(1, cfg.n_velocity_bins + 1))
        names.extend(f"Duration_{d}" for d in range(1, cfg.n_duration_bins + 1))
        # Program tokens (GM instruments)
        names.extend(f"Program_{prog}" for prog in range(128))
        # Control tokens: ten density levels, then ten polyphony levels
        names.extend(f"{cfg.ctrl_density_prefix}_{level}" for level in range(1, 11))
        names.extend(f"{cfg.ctrl_polyphony_prefix}_{level}" for level in range(1, 11))

        for tok in names:
            self._register_token(tok)

        self.base_vocab_size = len(self.id_to_token)
        self.vocab_size = self.base_vocab_size  # Will grow with BPE

        # Special token IDs
        self.pad_id = self.token_to_id[cfg.pad_token]
        self.bos_id = self.token_to_id[cfg.bos_token]
        self.eos_id = self.token_to_id[cfg.eos_token]
        self.mask_id = self.token_to_id[cfg.mask_token]
        self.bar_id = self.token_to_id[cfg.bar_token]

    def _quantize_tempo(self, tempo: float) -> int:
        """Snap ``tempo`` (BPM) to the nearest tempo value present in the vocabulary.

        The tempo tokens start at ``tempo_range[0]`` and advance by
        ``tempo_step``, so the grid must be anchored at the lower bound;
        rounding to plain multiples of the step can produce values (such as
        120 with the defaults) that have no token.
        """
        low, high = self.config.tempo_range
        step = self.config.tempo_step
        last_step = (high - low) // step
        n_steps = int(round((tempo - low) / step))
        n_steps = max(0, min(last_step, n_steps))
        return low + n_steps * step

    @staticmethod
    def _suffix_int(tok: str) -> int:
        """Return the integer after the last underscore of a ``Name_<int>`` token."""
        return int(tok.rsplit("_", 1)[1])

    def midi_to_remi_tokens(self, notes: List[Dict], tempo: float = 120.0,
                            time_sig: Tuple[int, int] = (4, 4)) -> List[str]:
        """
        Convert a list of note events to REMI+ token strings.

        The output is ``TRACK_START``, an optional ``TimeSig`` token (only if
        the signature is in the vocabulary), one ``Tempo`` token, then for
        each note ``Position, Pitch, Velocity, Duration`` with a ``BAR`` token
        emitted for every bar reached, and finally ``TRACK_END``.

        Args:
            notes: List of dicts with keys: pitch, start, duration, velocity,
                program. ``start`` and ``duration`` are in ticks. Missing keys
                default to start=0, pitch=60, duration=one beat, velocity=80.
                ``program`` is accepted but not emitted as a token here.
            tempo: BPM
            time_sig: (numerator, denominator)
        Returns:
            List of REMI+ token strings (empty if ``notes`` is empty)
        """
        if not notes:
            return []

        cfg = self.config

        # Sort by start time, then pitch, for a deterministic token order
        ordered = sorted(notes, key=lambda n: (n.get('start', 0), n.get('pitch', 0)))

        # Compute bar and position info
        tpb = cfg.ticks_per_beat
        beats_per_bar = time_sig[0] * (4.0 / time_sig[1])
        ticks_per_bar = int(tpb * beats_per_bar)
        # Never let the grid step collapse to 0 (would divide by zero below)
        ticks_per_position = max(1, ticks_per_bar // cfg.position_resolution)

        # Track metadata
        tokens = [cfg.track_start_token]

        ts_tok = f"TimeSig_{time_sig[0]}/{time_sig[1]}"
        if ts_tok in self.token_to_id:
            tokens.append(ts_tok)

        tokens.append(f"Tempo_{self._quantize_tempo(tempo)}")

        low_pitch, high_pitch = cfg.pitch_range
        current_bar = -1

        for note in ordered:
            start = note.get('start', 0)
            pitch = note.get('pitch', 60)
            duration = note.get('duration', tpb)
            velocity = note.get('velocity', 80)

            # Add bar tokens for any new bars (including empty ones skipped over)
            bar = int(start // ticks_per_bar)
            while current_bar < bar:
                current_bar += 1
                tokens.append(cfg.bar_token)

            # Position within bar (quantized down to the grid)
            pos_in_bar = start % ticks_per_bar
            position = min(int(pos_in_bar / ticks_per_position),
                           cfg.position_resolution - 1)
            tokens.append(f"Position_{position}")

            # Pitch: out-of-range pitches are clamped to the nearest bound
            pitch = max(low_pitch, min(high_pitch, pitch))
            tokens.append(f"Pitch_{pitch}")

            # Velocity (bin 1..n_velocity_bins)
            vel_bin = int(velocity / 128 * cfg.n_velocity_bins) + 1
            vel_bin = max(1, min(cfg.n_velocity_bins, vel_bin))
            tokens.append(f"Velocity_{vel_bin}")

            # Duration (bin 1..n_duration_bins, in grid slots)
            dur_bin = int(duration / ticks_per_position)
            dur_bin = max(1, min(cfg.n_duration_bins, dur_bin))
            tokens.append(f"Duration_{dur_bin}")

        tokens.append(cfg.track_end_token)
        return tokens

    def encode(self, tokens: List[str]) -> List[int]:
        """Convert token strings to integer IDs, wrapped in BOS ... EOS.

        Tokens that are not in the vocabulary are skipped.
        """
        lookup = self.token_to_id
        ids = [self.bos_id]
        ids.extend(lookup[tok] for tok in tokens if tok in lookup)
        ids.append(self.eos_id)
        return ids

    def decode(self, ids: List[int]) -> List[str]:
        """Convert integer IDs to token strings.

        Unknown IDs and the PAD/BOS/EOS tokens are omitted from the result.
        """
        skipped = (self.config.pad_token, self.config.bos_token, self.config.eos_token)
        tokens = []
        for id_ in ids:
            tok = self.id_to_token.get(id_)
            if tok is not None and tok not in skipped:
                tokens.append(tok)
        return tokens

    def segment_into_phrases(self, tokens: List[str], bars_per_phrase: int = 1) -> List[List[str]]:
        """
        Segment a full REMI+ token sequence into phrase-level chunks.

        A new phrase starts at the BAR token that would exceed
        ``bars_per_phrase`` bars in the current phrase. Tokens that precede
        the first BAR token stay in the first phrase. With the default of one
        bar per phrase this follows the PhraseVAE convention.
        """
        phrases = []
        current_phrase = []
        bar_count = 0

        for tok in tokens:
            if tok == self.config.bar_token:
                bar_count += 1
                if bar_count > bars_per_phrase and current_phrase:
                    phrases.append(current_phrase)
                    current_phrase = []
                    bar_count = 1
            current_phrase.append(tok)

        if current_phrase:
            phrases.append(current_phrase)

        return phrases

    def compute_controls(self, phrase_tokens: List[str]) -> Dict[str, int]:
        """
        Compute control attributes from a phrase's tokens.

        Controls:
        - density: number of notes (binned 1-10, three notes per bin)
        - polyphony: max number of notes starting at the same position of
          the same bar (clamped to 1-10)
        """
        note_count = sum(1 for t in phrase_tokens if t.startswith("Pitch_"))

        # Density: bin into 1-10
        density = min(10, max(1, int(note_count / 3) + 1))

        # Polyphony: count notes per (bar, position). The bar index is part
        # of the key so equal positions in different bars are not merged.
        onsets = {}
        bar_index = 0
        current_pos = 0
        for tok in phrase_tokens:
            if tok == self.config.bar_token:
                bar_index += 1
                current_pos = 0
            elif tok.startswith("Position_"):
                current_pos = self._suffix_int(tok)
            elif tok.startswith("Pitch_"):
                key = (bar_index, current_pos)
                onsets[key] = onsets.get(key, 0) + 1

        max_poly = max(onsets.values()) if onsets else 1
        polyphony = min(10, max(1, max_poly))

        return {'density': density, 'polyphony': polyphony}

    def pad_sequence(self, ids: List[int], max_len: int) -> List[int]:
        """Pad with PAD IDs or truncate so the result has exactly max_len items."""
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_id] * (max_len - len(ids))

    def save(self, path: str):
        """Save tokenizer to directory.

        Writes ``tokenizer.json`` in ``path`` (created if missing) containing
        the token-to-ID table and every configuration field.
        """
        from dataclasses import asdict

        os.makedirs(path, exist_ok=True)
        data = {
            'token_to_id': self.token_to_id,
            # Store the full config (including time signatures and special
            # tokens) so a non-default tokenizer can be rebuilt faithfully.
            'config': asdict(self.config),
        }
        with open(os.path.join(path, 'tokenizer.json'), 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load(cls, path: str) -> 'REMIPlusTokenizer':
        """Load tokenizer from directory.

        Reads ``tokenizer.json`` from ``path``. Config keys absent from the
        file fall back to the TokenizerConfig defaults.
        """
        with open(os.path.join(path, 'tokenizer.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)

        # JSON turns tuples into lists; restore the declared types.
        cfg_kwargs = dict(data['config'])
        for key in ('pitch_range', 'tempo_range'):
            if key in cfg_kwargs:
                cfg_kwargs[key] = tuple(cfg_kwargs[key])
        if 'time_signatures' in cfg_kwargs:
            cfg_kwargs['time_signatures'] = [tuple(ts) for ts in cfg_kwargs['time_signatures']]

        tokenizer = cls(TokenizerConfig(**cfg_kwargs))
        tokenizer.token_to_id = data['token_to_id']
        tokenizer.id_to_token = {int(idx): tok for tok, idx in data['token_to_id'].items()}
        # Keep vocab_size in step with the stored table (it may hold more
        # entries than the freshly built base vocabulary).
        if tokenizer.id_to_token:
            tokenizer.vocab_size = max(tokenizer.vocab_size,
                                       max(tokenizer.id_to_token) + 1)
        return tokenizer

    def tokens_to_midi_notes(self, tokens: List[str], ticks_per_beat: int = 480) -> List[Dict]:
        """
        Convert REMI+ tokens back to note events.

        A note is emitted at each ``Duration`` token that follows a ``Pitch``
        token, using the most recent bar, position and velocity. Tempo and
        other tokens are ignored. A 4/4 signature is assumed until a
        ``TimeSig`` token is seen.

        Returns list of dicts: {pitch, start, duration, velocity}, with
        ``start`` and ``duration`` in ticks.
        """
        cfg = self.config
        notes = []
        current_bar = -1
        current_position = 0

        # Default 4/4 grid
        ticks_per_bar = ticks_per_beat * 4
        ticks_per_position = max(1, ticks_per_bar // cfg.position_resolution)

        # Velocity bin used when a note has no Velocity token (mid-range)
        fallback_velocity_bin = max(1, cfg.n_velocity_bins // 2)

        # Pending note attributes
        pending_pitch = None
        pending_velocity = None

        for tok in tokens:
            if tok.startswith("TimeSig_"):
                num, den = tok.split("_")[1].split("/")
                beats_per_bar = int(num) * (4.0 / int(den))
                ticks_per_bar = int(ticks_per_beat * beats_per_bar)
                ticks_per_position = max(1, ticks_per_bar // cfg.position_resolution)

            elif tok == cfg.bar_token:
                current_bar += 1

            elif tok.startswith("Position_"):
                current_position = self._suffix_int(tok)

            elif tok.startswith("Pitch_"):
                pending_pitch = self._suffix_int(tok)

            elif tok.startswith("Velocity_"):
                pending_velocity = self._suffix_int(tok)

            elif tok.startswith("Duration_"):
                if pending_pitch is not None:
                    dur_bin = self._suffix_int(tok)
                    start = current_bar * ticks_per_bar + current_position * ticks_per_position
                    vel_bin = pending_velocity or fallback_velocity_bin
                    velocity = int(vel_bin / cfg.n_velocity_bins * 127)

                    notes.append({
                        'pitch': pending_pitch,
                        # Notes seen before any BAR token would land at a
                        # negative tick; pin them to 0.
                        'start': max(0, start),
                        'duration': dur_bin * ticks_per_position,
                        'velocity': min(127, max(1, velocity)),
                    })
                pending_pitch = None
                pending_velocity = None

        return notes
444
+
445
+
446
+ # ============================================================================
447
+ # MIDI File I/O Utilities (using midiutil or pretty_midi)
448
+ # ============================================================================
449
+
450
def notes_to_midi_file(notes: List[Dict], output_path: str,
                       tempo: float = 120.0, ticks_per_beat: int = 480) -> bool:
    """
    Write note events to a single-track MIDI file.

    Uses midiutil for lightweight MIDI writing (no heavy dependencies).

    Args:
        notes: Dicts with keys ``pitch``, ``start``, ``duration`` and
            ``velocity``; ``start`` and ``duration`` are in ticks.
        output_path: Destination file path.
        tempo: Tempo in BPM written at the start of the track.
        ticks_per_beat: Tick resolution used to convert ticks to beats.
    Returns:
        True if the file was written, False if midiutil is not installed.
    """
    # Only the optional import is guarded; errors while building or writing
    # the file are real failures and should propagate to the caller.
    try:
        from midiutil import MIDIFile
    except ImportError:
        print("midiutil not installed. Install with: pip install midiutil")
        return False

    midi = MIDIFile(1, ticks_per_quarternote=ticks_per_beat)
    midi.addTempo(0, 0, tempo)

    for note in notes:
        # midiutil expects times and durations in beats, not ticks
        start_beat = note['start'] / ticks_per_beat
        duration_beat = note['duration'] / ticks_per_beat
        midi.addNote(0, 0, note['pitch'], start_beat, duration_beat, note['velocity'])

    with open(output_path, 'wb') as f:
        midi.writeFile(f)

    return True
478
+
479
+
480
def midi_file_to_notes(midi_path: str) -> Tuple[List[Dict], float, Tuple[int, int]]:
    """
    Read a MIDI file and extract note events.

    Note times are converted to ticks at 480 ticks per beat using the file's
    own tempo map, so notes stay aligned with the bar grid even when the file
    contains tempo changes. Drum tracks are skipped.

    Args:
        midi_path: Path of the MIDI file to read.
    Returns: (notes, tempo, time_signature)
        ``notes`` is a list of dicts with keys pitch, start, duration,
        velocity and program (start/duration in ticks); ``tempo`` is the
        file's first tempo in BPM; ``time_signature`` is the first time
        signature in the file, or (4, 4) if it has none. If pretty_midi is
        not installed the result is ``([], 120.0, (4, 4))``.
    """
    # Only the optional import is guarded; parse errors should propagate.
    try:
        import pretty_midi
    except ImportError:
        print("pretty_midi not installed. Install with: pip install pretty_midi")
        return [], 120.0, (4, 4)

    pm = pretty_midi.PrettyMIDI(midi_path)

    # Read the tempo stored in the file rather than estimating it from note
    # onsets: the estimate is a heuristic that can disagree with the real
    # tempo and fails outright on files with fewer than two notes.
    _, tempi = pm.get_tempo_changes()
    tempo = float(tempi[0]) if len(tempi) else 120.0

    # Get time signature
    if pm.time_signature_changes:
        first_sig = pm.time_signature_changes[0]
        time_sig = (first_sig.numerator, first_sig.denominator)
    else:
        time_sig = (4, 4)

    tpb = 480  # Standard ticks per beat
    # Rescale from the file's own resolution to the standard one
    tick_scale = tpb / pm.resolution

    notes = []
    for instrument in pm.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            start_ticks = int(round(pm.time_to_tick(note.start) * tick_scale))
            end_ticks = int(round(pm.time_to_tick(note.end) * tick_scale))

            notes.append({
                'pitch': note.pitch,
                'start': start_ticks,
                # Keep very short notes from collapsing to zero length
                'duration': max(1, end_ticks - start_ticks),
                'velocity': note.velocity,
                'program': instrument.program,
            })

    return notes, tempo, time_sig