File size: 26,180 Bytes
aa189b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
import json
import random
from dataclasses import dataclass
from functools import lru_cache
from math import ceil, floor, log
from typing import Dict, Iterator, List, Optional, Tuple

import mido


@dataclass
class VocabConfig:
    """Vocabulary and instrument-binning settings for the MIDI <-> token codec.

    Besides the declared fields, ``__post_init__`` derives several lookup
    tables: ``_instrument_names_str_to_int`` / ``_instrument_names_int_to_str``,
    ``_bin_str_to_int``, ``_bin_int_to_instrument_int``,
    ``_instrument_int_to_bin_int``, ``_ch10_bin_int``,
    ``short_instr_bin_names``, ``_short_instrument_names_str_to_int`` and
    ``bin_channel_map``.

    Raises:
        ValueError: if ``validate`` rejects the settings, or if no unique
            short name can be derived for an instrument bin.
    """

    # Number of note events. Should be 128.
    note_events: int
    # Number of wait events. Configurable, must evenly divide max_wait_time.
    wait_events: int
    # Max wait time in milliseconds to be represented by a single token.
    max_wait_time: int
    # Number of velocity events. Should be 128 (or 100? need to check midi standard)
    velocity_events: int
    # Number of bins to quantize velocity into. Should evenly divide velocity_events.
    velocity_bins: int
    # Exponential scaling factor for velocity bin sizes. 1.0 = linear scaling.
    velocity_exp: float
    # Whether to sort tokens by instrument, note. This should improve data reducibility.
    do_token_sorting: bool
    # Whether tokens should be represented as combined instrument/note/velocity tokens, or separate tokens for each.
    unrolled_tokens: bool
    # If non-zero, notes held for this many seconds will be automatically released during str->midi decoding.
    decode_end_held_note_delay: float
    # If true, repeated notes will be automatically released before playing again during str->midi decoding.
    decode_fix_repeated_notes: bool
    # List of instrument names to use for binning. Must have at most 16 values.
    bin_instrument_names: List[str]
    # Indicates which bin name represents percussion instruments on MIDI channel 10.
    ch10_instrument_bin_name: str
    # Mapping from instrument name to bin name.
    program_name_to_bin_name: Dict[str, str]
    # Mapping from bin name to program name.
    bin_name_to_program_name: Dict[str, str]
    # Mapping from program number to instrument name.
    instrument_names: Dict[str, str]
    # Manual override for velocity bins. Each element is the max velocity value for that bin by index.
    velocity_bins_override: Optional[List[int]] = None

    def __post_init__(self):
        """Validate the settings and build the derived lookup tables."""
        self.validate()

        # instrument_names is keyed by program number (as a string, e.g. from JSON).
        self._instrument_names_str_to_int = {name: int(i) for i, name in self.instrument_names.items()}
        self._instrument_names_int_to_str = {int(i): name for i, name in self.instrument_names.items()}

        self._bin_str_to_int = {name: i for i, name in enumerate(self.bin_instrument_names)}

        # bin index -> representative program number (the channel-10 bin uses program 0).
        self._bin_int_to_instrument_int = [
            0 if name == self.ch10_instrument_bin_name
            else self._instrument_names_str_to_int[self.bin_name_to_program_name[name]]
            for name in self.bin_instrument_names
        ]
        # One entry per program_name_to_bin_name item, in dict order; -1 marks an
        # instrument with no bin (empty bin name). Callers index this list by
        # program number, so the dict is expected to be ordered by program number.
        self._instrument_int_to_bin_int = [
            self._bin_str_to_int[bin_name] if bin_name != "" else -1
            for bin_name in self.program_name_to_bin_name.values()
        ]

        self._ch10_bin_int = self._bin_str_to_int[self.ch10_instrument_bin_name] if self.ch10_instrument_bin_name else -1

        # Shortest prefix of each bin name that is not already used as a short name.
        self.short_instr_bin_names = []
        for instr in self.bin_instrument_names:
            i = min(1, len(instr))
            while instr[:i] in self.short_instr_bin_names:
                if i >= len(instr):
                    # Every prefix (including the full name) is taken; without
                    # this guard the loop would never terminate.
                    raise ValueError(f"cannot derive a unique short name for instrument bin {instr!r}")
                i += 1
            self.short_instr_bin_names.append(instr[:i])
        self._short_instrument_names_str_to_int = {name: i for i, name in enumerate(self.short_instr_bin_names)}

        # Assign channels to the non-percussion bins in order, skipping channel
        # index 9, which is reserved for the channel-10 bin.
        range_excluding_ch10 = [(i if i < 9 else i + 1) for i in range(len(self.bin_instrument_names))]
        bins_excluding_ch10 = [n for n in self.bin_instrument_names if n != self.ch10_instrument_bin_name]
        self.bin_channel_map = dict(zip(bins_excluding_ch10, range_excluding_ch10))
        if self.ch10_instrument_bin_name:
            self.bin_channel_map[self.ch10_instrument_bin_name] = 9

    def validate(self):
        """Check the declared fields for consistency.

        Raises:
            ValueError: on the first violated constraint.
        """
        if self.max_wait_time % self.wait_events != 0:
            raise ValueError("max_wait_time must be exactly divisible by wait_events")
        if self.velocity_bins < 2:
            raise ValueError("velocity_bins must be at least 2")
        if len(self.bin_instrument_names) > 16:
            raise ValueError("bin_instruments must have at most 16 values")
        if self.velocity_bins_override:
            print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
            if len(self.velocity_bins_override) != self.velocity_bins:
                raise ValueError("velocity_bins_override must have same length as velocity_bins")
        if self.ch10_instrument_bin_name and self.ch10_instrument_bin_name not in self.bin_instrument_names:
            raise ValueError("ch10_instrument_bin_name must be in bin_instruments")
        if self.velocity_exp <= 0:
            raise ValueError("velocity_exp must be greater than 0")

    @classmethod
    def from_json(cls, path: str):
        """Build a config from a JSON file whose keys are the field names."""
        with open(path, "r") as f:
            config = json.load(f)
        return cls(**config)


class VocabUtils:
    """Token formatting, parsing and quantization helpers bound to one config.

    The token formatters are deliberately not wrapped in ``functools.lru_cache``:
    a cache on an instance method keys on ``self`` and keeps every instance
    (and its config) alive, while the cached work is a single f-string.
    """

    def __init__(self, cfg: "VocabConfig") -> None:
        self.cfg = cfg

    def format_wait_token(self, wait: int) -> str:
        """Return the wait token for ``wait`` wait-steps, e.g. ``t5``."""
        return f"t{wait}"

    def format_note_token(self, instrument_bin: int, note: int, velocity_bin: int) -> str:
        """Return the combined ``<short bin name>:<note hex>:<velocity bin hex>`` token."""
        return f"{self.cfg.short_instr_bin_names[instrument_bin]}:{note:x}:{velocity_bin:x}"

    def format_unrolled_note(self, note: int) -> str:
        """Return the unrolled note token, ``n`` followed by the note in hex."""
        return f"n{note:x}"

    def format_unrolled_velocity(self, velocity_bin: int) -> str:
        """Return the unrolled velocity token, ``v`` followed by the bin in hex."""
        return f"v{velocity_bin:x}"

    def format_unrolled_instrument_bin(self, instrument_bin: int) -> str:
        """Return the unrolled instrument token, ``i`` followed by the short bin name."""
        return f"i{self.cfg.short_instr_bin_names[instrument_bin]}"

    def velocity_to_bin(self, velocity: float) -> int:
        """Quantize a velocity to a bin index.

        The velocity is first clamped to ``[0, velocity_events - 1]``. With
        ``velocity_bins_override`` set, the result is the first bin whose
        maximum is >= the velocity; a velocity above every maximum maps to the
        last (loudest) bin rather than to bin 0. Otherwise bins are spaced
        linearly (``velocity_exp == 1.0``) or along an exponential curve.
        """
        velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
        override = self.cfg.velocity_bins_override
        if override:
            for i, v in enumerate(override):
                if velocity <= v:
                    return i
            # Louder than every configured maximum: keep it in the top bin.
            return len(override) - 1
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
        if self.cfg.velocity_exp == 1.0:
            return ceil(velocity / binsize)
        else:
            return ceil((self.cfg.velocity_events*((self.cfg.velocity_exp**(velocity/self.cfg.velocity_events)-1.0) / (self.cfg.velocity_exp-1.0))) / binsize)

    def bin_to_velocity(self, bin: int) -> int:
        """Return the velocity that represents bin index ``bin`` (inverse of ``velocity_to_bin``)."""
        if self.cfg.velocity_bins_override:
            return self.cfg.velocity_bins_override[bin]
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
        if self.cfg.velocity_exp == 1.0:
            return max(0, ceil(bin * binsize - 1))
        else:
            return max(0, ceil(self.cfg.velocity_events*log(((self.cfg.velocity_exp-1)*binsize*bin)/self.cfg.velocity_events+1, self.cfg.velocity_exp) - 1))

    def delta_to_wait_ids(self, delta_ms: float) -> Iterator[int]:
        """Yield wait-token ids covering ``delta_ms`` milliseconds.

        Delays longer than ten times ``max_wait_time`` are truncated to that
        length. One full-length id is yielded per whole ``max_wait_time``,
        followed by the rounded remainder if it is non-zero.
        """
        def roundi(f: float):
            # Round to nearest, with exact halves rounding down.
            return ceil(f - 0.5)

        max_wait_ms = self.cfg.max_wait_time
        div = max_wait_ms / self.cfg.wait_events

        if delta_ms > max_wait_ms * 10:
            delta_ms = max_wait_ms * 10  # truncate time

        for _ in range(floor(delta_ms / max_wait_ms)):
            yield roundi(max_wait_ms / div)
        leftover_time_shift = roundi((delta_ms % max_wait_ms) / div)
        if leftover_time_shift > 0:
            yield leftover_time_shift

    def prog_data_to_token_data(self, program: int, channel: int, note: int, velocity: float) -> Optional[Tuple[int, int, int]]:
        """Map ``(program, channel, note, velocity)`` to ``(bin, note, velocity_bin)``.

        Channel index 9 always maps to the channel-10 bin. Returns ``None``
        when the event has no bin (no channel-10 bin configured, or the
        program's bin entry is -1).
        """
        if channel == 9:
            if self.cfg._ch10_bin_int == -1:
                return None
            return self.cfg._ch10_bin_int, note, self.velocity_to_bin(velocity)

        instrument_bin = self.cfg._instrument_int_to_bin_int[program]
        if instrument_bin != -1:
            return instrument_bin, note, self.velocity_to_bin(velocity)
        return None

    def prog_data_list_to_token_data_list(self, data: List[Tuple[int, int, int, float]]) -> Iterator[Tuple[int, int, int]]:
        """Yield token data for every entry of ``data`` that maps to a bin."""
        for d in data:
            token_data = self.prog_data_to_token_data(*d)
            if token_data is not None:
                yield token_data

    def sort_token_data(self, data: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
        """Sort by (channel-10 bin first, bin, note), keeping input order for ties."""
        # ensure order is preserved for tokens with the same instrument, note
        data = [(i, n, v, x) for x, (i, n, v) in enumerate(data)]
        data.sort(key=lambda x: (x[0] != self.cfg._ch10_bin_int, x[0], x[1], x[3]))
        return [(i, n, v) for i, n, v, _ in data]

    def data_to_wait_tokens(self, delta_ms: float) -> List[str]:
        """Return the wait tokens for ``delta_ms`` milliseconds (empty for 0.0)."""
        if delta_ms == 0.0:
            return []
        return [self.format_wait_token(i) for i in self.delta_to_wait_ids(delta_ms)]

    def wait_token_to_delta(self, token: str) -> float:
        """Return the duration in milliseconds encoded by a ``t<N>`` wait token."""
        return self.cfg.max_wait_time / self.cfg.wait_events * int(token[1:])

    def note_token_to_data(self, token: str) -> Tuple[int, int, int]:
        """Parse a combined note token into ``(bin, note, velocity)``.

        The velocity is the de-quantized value from ``bin_to_velocity``.
        """
        instr_str, note_str, velocity_str = token.strip().split(":")
        instr_bin = self.cfg._short_instrument_names_str_to_int[instr_str]
        note = int(note_str, base=16)
        velocity = self.bin_to_velocity(int(velocity_str, base=16))
        return instr_bin, note, velocity


@dataclass
class AugmentValues:
    # Source instrument bin index -> replacement bin index; bins not listed are left as they are.
    instrument_bin_remap: Dict[int, int]
    # Multiplier applied to note velocities (1.0 = unchanged).
    velocity_mod_factor: float
    # Number of semitones added to note numbers (0 = unchanged).
    transpose_semitones: int
    # Multiplier applied to the tempo value (1.0 = unchanged).
    time_stretch_factor: float

    @classmethod
    def default(cls) -> "AugmentValues":
        """Return the identity augmentation: no remap, no scaling, no transpose."""
        identity = {
            "instrument_bin_remap": {},
            "velocity_mod_factor": 1.0,
            "transpose_semitones": 0,
            "time_stretch_factor": 1.0,
        }
        return cls(**identity)


@dataclass
class AugmentConfig:
    """Settings that drive random data augmentation of MIDI files.

    ``__post_init__`` fills empty option lists with their neutral value and
    resolves ``instrument_mixups`` (bin names) into pools of bin indices:
    ``_instrument_mixups_int``, ``_mixup_pools`` and
    ``_instrument_pool_assignments``. Names unknown to ``cfg`` are dropped.
    """

    # The number of times to augment each MIDI file. The dataset size will be multiplied by this number.
    augment_data_factor: int
    # A list of instrument names to randomly swap with each other.
    instrument_mixups: List[List[str]]
    # A list of percentages to change the note velocity by. 0.0 = no change. 0 is included by default.
    velocity_mod_pct: List[float]
    # A list of semitones to transpose by. 0 is included by default.
    transpose_semitones: List[int]
    # A list of percentages to stretch the tempo by. 0.0 = no stretch. 0 is included by default.
    time_stretch_pct: List[float]
    # Random seed to use for reproducibility.
    seed: int

    # Vocabulary config; only its _bin_str_to_int table is read here.
    cfg: "VocabConfig"

    def __post_init__(self):
        """Validate, apply list defaults and build the mixup pools."""
        self.validate()
        if not self.velocity_mod_pct:
            self.velocity_mod_pct = [0.0]
        if not self.transpose_semitones:
            self.transpose_semitones = [0]
        if not self.time_stretch_pct:
            self.time_stretch_pct = [0.0]

        bin_index = self.cfg._bin_str_to_int
        resolved = [[bin_index[name] for name in names if name in bin_index] for names in self.instrument_mixups]
        self._instrument_mixups_int = [pool for pool in resolved if pool]  # remove empty lists
        self._instrument_pool_assignments = {}
        self._mixup_pools = []
        for pool_i, mixup_list in enumerate(self._instrument_mixups_int):
            for i in mixup_list:
                self._instrument_pool_assignments[i] = pool_i
            self._mixup_pools.append(set(mixup_list))

    def validate(self):
        """Check the declared fields.

        Raises:
            ValueError: if ``augment_data_factor`` is below 1 or an instrument
                name appears more than once across ``instrument_mixups``.
        """
        if self.augment_data_factor < 1:
            raise ValueError("augment_data_factor must be at least 1")
        used_instruments = set()
        for mixup_list in self.instrument_mixups:
            for n in mixup_list:
                if n in used_instruments:
                    raise ValueError(f"Duplicate instrument name: {n}")
                used_instruments.add(n)

    @classmethod
    def from_json(cls, path: str, cfg: "VocabConfig"):
        """Build a config from a JSON file; a missing ``seed`` is drawn at random."""
        with open(path, "r") as f:
            config = json.load(f)
        config["cfg"] = cfg
        if "seed" not in config:
            config["seed"] = random.randint(0, 2**32 - 1)
        return cls(**config)

    def _shuffled_instrument_remap(self, rng: random.Random) -> Dict[int, int]:
        """Return a random bin remap that only swaps bins within the same pool.

        Each pool is shuffled and every bin is mapped to its predecessor in
        the shuffled order (cyclically), so pools of any size are handled and
        a bin is never mapped into another pool.
        """
        remap: Dict[int, int] = {}
        for pool in self._mixup_pools:
            order = sorted(pool)  # fixed starting order before shuffling
            rng.shuffle(order)
            for j, instrument in enumerate(order):
                remap[instrument] = order[j - 1]
        return remap

    def get_augment_values(self, filename: str) -> Iterator["AugmentValues"]:
        """Yield ``augment_data_factor`` augmentations for ``filename``.

        The first value is always the identity augmentation; the remaining
        ones are drawn from a generator seeded by ``seed`` and ``filename``.
        """
        # first yield default values
        yield AugmentValues.default()

        # Seed with a string instead of hash(filename): str hashes are salted
        # per interpreter process, which would make the "reproducible" seed
        # produce different augmentations on every run.
        rng = random.Random(f"{self.seed}:{filename}")
        for _ in range(int(self.augment_data_factor - 1)):
            instrument_bin_remap = self._shuffled_instrument_remap(rng)
            yield AugmentValues(
                instrument_bin_remap=instrument_bin_remap,
                velocity_mod_factor=1.0 + rng.choice(self.velocity_mod_pct),
                transpose_semitones=rng.choice(self.transpose_semitones),
                time_stretch_factor=1.0 + rng.choice(self.time_stretch_pct),
            )


@dataclass
class FilterConfig:
    # Whether to filter out MIDI files with duplicate MD5 hashes.
    deduplicate_md5: bool
    # Silence between notes that splits a file into multiple documents.
    # convert_midi_to_str multiplies this by 1000 before comparing it with
    # millisecond deltas, i.e. the value is in seconds.
    piece_split_delay: float
    # Minimum length a piece must exceed to be kept. convert_midi_to_str also
    # multiplies this by 1000 before comparing it with a millisecond length,
    # i.e. the value is in seconds.
    min_piece_length: float

    @classmethod
    def from_json(cls, path: str):
        """Build a FilterConfig from a JSON file whose keys are the field names."""
        with open(path, "r") as config_file:
            return cls(**json.load(config_file))


def mix_volume(velocity: int, volume: int, expression: int) -> float:
    """Scale a note velocity by channel volume and expression, each taken as a fraction of 127."""
    volume_scale = volume / 127.0
    expression_scale = expression / 127.0
    return velocity * volume_scale * expression_scale


def convert_midi_to_str(cfg: VocabConfig, filter_cfg: FilterConfig, mid: mido.MidiFile, augment: AugmentValues = None) -> List[str]:
    """Encode a MIDI file as a list of space-separated token strings.

    Each returned string is one piece of the form ``<start> ... <end>``. A new
    piece is started when a valid note event arrives more than
    ``filter_cfg.piece_split_delay * 1000`` ms after the previous one while no
    notes are held. Only pieces whose accumulated wait time exceeds
    ``filter_cfg.min_piece_length * 1000`` ms are kept.

    Note-on velocities are scaled by the channel volume and expression
    (``mix_volume``) and by ``augment.velocity_mod_factor``; note-offs are
    encoded as velocity-0 events. ``augment`` defaults to
    ``AugmentValues.default()``.

    Side effect: ``mid`` is modified in place -- "unknown_meta" messages are
    removed from every track and multiple tracks are merged into one.
    """
    utils = VocabUtils(cfg)
    if augment is None:
        augment = AugmentValues.default()

    # filter out unknown meta messages before merge (https://github.com/mido/mido/pull/286)
    for i in range(len(mid.tracks)):
        mid.tracks[i] = [msg for msg in mid.tracks[i] if msg.type != "unknown_meta"]

    if len(mid.tracks) > 1:
        mid.tracks = [mido.merge_tracks(mid.tracks)]

    # Time accumulated since the last consumed note event, in milliseconds.
    delta_time_ms = 0.0
    # Tempo used for tick -> time conversion until a set_tempo message is seen.
    tempo = 500000
    # Per-channel state for the 16 channels.
    channel_program = {i: 0 for i in range(16)}
    channel_volume = {i: 127 for i in range(16)}
    channel_expression = {i: 127 for i in range(16)}  # unlikely to be useful. expression usually modifies an already played note.
    channel_notes = {i: {} for i in range(16)}  # {channel: {note -> True}} for notes currently held
    channel_pedal_on = {i: False for i in range(16)}
    channel_pedal_events = {i: {} for i in range(16)}  # {channel: {(note, program) -> True}}
    # False until the first valid note of the current piece; suppresses leading wait tokens.
    started_flag = False

    output_list = []
    output = ["<start>"]
    output_length_ms = 0.0
    token_data_buffer: List[Tuple[int, int, int, float]] = []  # need to sort notes between wait tokens

    def flush_token_data_buffer():
        """Convert the buffered (program, channel, note, velocity) events to tokens and append them to ``output``."""
        nonlocal token_data_buffer, output, cfg, utils, augment
        token_data = [x for x in utils.prog_data_list_to_token_data_list(token_data_buffer)]
        if augment.instrument_bin_remap or augment.transpose_semitones:
            # TODO put transpose in a real function
            # The channel-10 bin is never transposed.
            raw_transpose = lambda bin, n: n + augment.transpose_semitones if bin != cfg._ch10_bin_int else n
            # Notes pushed out of [0, note_events) are moved back by one octave.
            octave_shift_if_oob = lambda n: n + 12 if n < 0 else n - 12 if n >= cfg.note_events else n
            # TODO handle ranges beyond 12
            #octave_shift_if_oob = lambda n: 0 if n < 0 else (n - cfg.note_events) % 12 + cfg.note_events if n >= cfg.note_events else n
            transpose = lambda bin, n: octave_shift_if_oob(raw_transpose(bin, n))

            # Transposition is decided from the original bin, before the remap is applied.
            token_data = [(augment.instrument_bin_remap.get(i, i), transpose(i, n), v) for i, n, v in token_data]
        if cfg.do_token_sorting:
            token_data = utils.sort_token_data(token_data)
        if cfg.unrolled_tokens:
            for t in token_data:
                output += [utils.format_unrolled_instrument_bin(t[0]), utils.format_unrolled_note(t[1]), utils.format_unrolled_velocity(t[2])]
        else:
            output += [utils.format_note_token(*t) for t in token_data]
        token_data_buffer = []

    def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
        """Buffer one note event, first emitting any pending wait tokens or piece split."""
        nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
        # Events that map to no instrument bin are ignored and do not reset delta_time_ms.
        is_token_valid = utils.prog_data_to_token_data(prog, chan, note, vel) is not None
        if not is_token_valid:
            return

        if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
            # check if any notes are still held
            silent = True
            for channel in channel_notes.keys():
                if len(channel_notes[channel]) > 0:
                    silent = False
                    break
            if silent:
                # Close the current piece and start a new one.
                flush_token_data_buffer()
                output.append("<end>")
                if output_length_ms > filter_cfg.min_piece_length * 1000.0:
                    output_list.append(" ".join(output))
                output = ["<start>"]
                output_length_ms = 0.0
                started_flag = False
        if started_flag:
            wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
            if len(wait_tokens) > 0:
                # Notes buffered so far belong before the wait.
                flush_token_data_buffer()
                output_length_ms += delta_time_ms
                output += wait_tokens
        delta_time_ms = 0.0
        token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
        started_flag = True

    for msg in mid.tracks[0]:
        time_ms = mido.tick2second(msg.time, mid.ticks_per_beat, tempo) * 1000.0
        delta_time_ms += time_ms
        t = msg.type

        if msg.is_meta:
            if t == "set_tempo":
                tempo = msg.tempo * augment.time_stretch_factor
            continue

        def handle_note_off(ch, prog, n):
            # While the sustain pedal is down the release is only recorded; it
            # is replayed when the pedal is lifted (control 64 branch below).
            if channel_pedal_on[ch]:
                channel_pedal_events[ch][(n, prog)] = True
            else:
                consume_note_program_data(prog, ch, n, 0)
                if n in channel_notes[ch]:
                    del channel_notes[ch][n]

        if t == "program_change":
            channel_program[msg.channel] = msg.program
        elif t == "note_on":
            if msg.velocity == 0:
                # note_on with velocity 0 is handled as a note-off
                handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
            else:
                # A re-struck note cancels its pending pedal-deferred release.
                if (msg.note, channel_program[msg.channel]) in channel_pedal_events[msg.channel]:
                    del channel_pedal_events[msg.channel][(msg.note, channel_program[msg.channel])]
                consume_note_program_data(
                    channel_program[msg.channel],
                    msg.channel,
                    msg.note,
                    mix_volume(msg.velocity, channel_volume[msg.channel], channel_expression[msg.channel]),
                )
                channel_notes[msg.channel][msg.note] = True
        elif t == "note_off":
            handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
        elif t == "control_change":
            # NOTE(review): controller 39 is presumably the channel-volume LSB; it
            # overwrites the whole volume value here -- confirm this is intended.
            if msg.control == 7 or msg.control == 39:  # volume
                channel_volume[msg.channel] = msg.value
            elif msg.control == 11:  # expression
                channel_expression[msg.channel] = msg.value
            elif msg.control == 64:  # sustain pedal
                channel_pedal_on[msg.channel] = msg.value >= 64
                if not channel_pedal_on[msg.channel]:
                    # Pedal lifted: emit the releases that were deferred while it was down.
                    for (note, program) in channel_pedal_events[msg.channel]:
                        handle_note_off(msg.channel, program, note)
                    channel_pedal_events[msg.channel] = {}
            elif msg.control == 123:  # all notes off
                # NOTE(review): releases held notes on every channel, not only
                # msg.channel -- confirm this is intended.
                for channel in channel_notes.keys():
                    for note in list(channel_notes[channel]).copy():
                        handle_note_off(channel, channel_program[channel], note)
        else:
            pass

    # Close the final piece.
    flush_token_data_buffer()
    output.append("<end>")
    if output_length_ms > filter_cfg.min_piece_length * 1000.0:
        output_list.append(" ".join(output))
    return output_list


def generate_program_change_messages(cfg: VocabConfig):
    """Yield one program_change message per instrument bin channel.

    Every bin in ``cfg.bin_channel_map`` except the one on channel index 9
    gets the program of its representative instrument; a final message sets
    program 0 on channel index 9.
    """
    for bin_name, channel in cfg.bin_channel_map.items():
        if channel != 9:
            program_name = cfg.bin_name_to_program_name[bin_name]
            program = cfg._instrument_names_str_to_int[program_name]
            yield mido.Message("program_change", program=program, time=0, channel=channel)
    yield mido.Message("program_change", program=0, time=0, channel=9)


@dataclass
class DecodeState:
    # Mutable decoder state carried from token to token by token_to_midi_message.
    # Total decoded time so far; advanced by wait tokens and the <end> pause.
    total_time: float  # milliseconds
    # Time accumulated since the last emitted message; reset to 0 when a message is emitted.
    delta_accum: float  # milliseconds
    # Instrument bin index selected by the most recent unrolled "i" token.
    current_bin: int
    # Note number selected by the most recent unrolled "n" token.
    current_note: int
    # Notes currently sounding, keyed by (channel, note), with the total_time at which they started.
    active_notes: Dict[Tuple[int, int], float]  # { (channel, note): time started, ... }


def token_to_midi_message(utils: VocabUtils, token: str, state: DecodeState, end_token_pause: float = 3.0) -> Iterator[Tuple[Optional[mido.Message], DecodeState]]:
    """Decode one token into zero or more MIDI messages.

    Yields ``(message, state)`` pairs; ``message`` is ``None`` when the token
    produces no MIDI output. At least one pair is always yielded so the caller
    can pick up the (possibly newly created) state.

    Args:
        utils: Vocabulary helpers; ``utils.cfg`` selects unrolled vs. combined tokens.
        token: A single token; surrounding whitespace is ignored.
        state: Decoder state from the previous call, or ``None`` to start fresh.
            The state object is mutated in place.
        end_token_pause: Seconds of silence added when ``<end>`` is seen.
    """
    if state is None:
        state = DecodeState(total_time=0.0, delta_accum=0.0, current_bin=utils.cfg._short_instrument_names_str_to_int[utils.cfg.short_instr_bin_names[0]], current_note=0, active_notes={})

    def consume_delta_ticks() -> int:
        # Convert the pending wait time to ticks (480 ticks per beat at a
        # tempo of 500000) and reset the accumulator.
        ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
        state.delta_accum = 0.0
        return ticks

    token = token.strip()
    if not token:
        yield None, state
        return
    if token == "<end>":
        d = end_token_pause * 1000.0
        state.delta_accum += d
        state.total_time += d
        if utils.cfg.decode_end_held_note_delay != 0.0:
            # end held notes
            for channel, note in list(state.active_notes):
                ticks = consume_delta_ticks()
                del state.active_notes[(channel, note)]
                yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
        yield None, state
        return
    if token.startswith("<"):
        # Other special tokens produce no MIDI output.
        yield None, state
        return

    if utils.cfg.unrolled_tokens:
        if token[0] == "t":
            d = utils.wait_token_to_delta(token)
            state.delta_accum += d
            state.total_time += d
        elif token[0] == "n":
            state.current_note = int(token[1:], base=16)
        elif token[0] == "i":
            state.current_bin = utils.cfg._short_instrument_names_str_to_int[token[1:]]
        elif token[0] == "v":
            # The velocity token completes a note; emit it with the current bin and note.
            current_velocity = utils.bin_to_velocity(int(token[1:], base=16))
            channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[state.current_bin]]
            ticks = consume_delta_ticks()
            if current_velocity > 0:
                yield mido.Message("note_on", note=state.current_note, velocity=current_velocity, time=ticks, channel=channel), state
            else:
                yield mido.Message("note_off", note=state.current_note, velocity=0, time=ticks, channel=channel), state
    else:
        # A short instrument name may itself start with "t", so a wait token
        # is recognised by "t" followed by a digit.
        if token[0] == "t" and token[1].isdigit():  # wait token
            d = utils.wait_token_to_delta(token)
            state.delta_accum += d
            state.total_time += d
            if utils.cfg.decode_end_held_note_delay != 0.0:
                # Release every note that has been held for too long. All
                # expired notes are released here, not only the first one;
                # releases after the first get a delta of 0 ticks.
                for (channel, note), start_time in list(state.active_notes.items()):
                    if state.total_time - start_time > utils.cfg.decode_end_held_note_delay * 1000.0:
                        ticks = consume_delta_ticks()
                        del state.active_notes[(channel, note)]
                        yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
        else:  # note token
            bin, note, velocity = utils.note_token_to_data(token)
            channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[bin]]
            ticks = consume_delta_ticks()
            if velocity > 0:
                if utils.cfg.decode_fix_repeated_notes:
                    if (channel, note) in state.active_notes:
                        # Release the still-sounding note before re-striking it.
                        del state.active_notes[(channel, note)]
                        yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
                        ticks = 0
                state.active_notes[(channel, note)] = state.total_time
                yield mido.Message("note_on", note=note, velocity=velocity, time=ticks, channel=channel), state
                return
            else:
                if (channel, note) in state.active_notes:
                    del state.active_notes[(channel, note)]
                yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
                return
    yield None, state


def str_to_midi_messages(utils: VocabUtils, data: str) -> Iterator[mido.Message]:
    """Decode a space-separated token string into a stream of MIDI messages.

    The decoder state returned for each token is fed into the next call;
    tokens that produce no message are skipped.
    """
    decode_state = None
    for tok in data.split(" "):
        for message, decode_state in token_to_midi_message(utils, tok, decode_state):
            if message is not None:
                yield message


def convert_str_to_midi(cfg: VocabConfig, data: str, meta_text: str = "Generated by MIDI-LLM-tokenizer") -> mido.MidiFile:
    """Decode a token string into a single-track MIDI file.

    The track starts with an optional text meta message (when ``meta_text`` is
    non-empty), a set_tempo of 500000 and the program changes for every
    instrument bin, followed by the decoded messages and an end_of_track.
    Special tokens such as ``<start>``/``<end>`` are left in ``data`` and
    handled by the token decoder.
    """
    utils = VocabUtils(cfg)
    track = mido.MidiTrack()
    mid = mido.MidiFile()
    mid.tracks.append(track)

    tempo = 500000
    header = []
    if meta_text:
        header.append(mido.MetaMessage("text", text=meta_text, time=0))
    header.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))
    track.extend(header)
    track.extend(generate_program_change_messages(cfg))

    track.extend(str_to_midi_messages(utils, data))

    track.append(mido.MetaMessage("end_of_track", time=0))
    return mid