# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" tokenizer.py: Encodes and decodes events to/from tokens. """
import numpy as np
import warnings
from abc import ABC, abstractmethod
from utils.note_event_dataclasses import Event, EventRange, Note
from utils.event_codec import FastCodec as Codec
from utils.note_event_dataclasses import NoteEvent
from utils.note2event import note_event2event
from utils.event2note import event2note_event, note_event2note
from typing import List, Optional, Union, Tuple, Dict


#TODO: Too complex to be an abstract class.
class EventTokenizerBase(ABC):
    """
    A base class for encoding and decoding events to and from tokens.
    """

    def __init__(
        self,
        base_codec: Union[Codec, str] = 'mt3',
        special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
        extra_tokens: List[str] = [],
        max_shift_steps: int = 206,  # 1001 in Gardner et al.
        program_vocabulary: Optional[Dict] = None,
        drum_vocabulary: Optional[Dict] = None,
    ) -> None:
        """
        Initializes the EventTokenizerBase object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        if isinstance(base_codec, str):
            # If codec is a string, initialize codec with the appropriate Codec object.
            if base_codec.lower() == 'mt3':
                event_ranges = [
                    EventRange('pitch', min_value=0, max_value=127),
                    EventRange('velocity', min_value=0, max_value=1),
                    EventRange('tie', min_value=0, max_value=0),
                    EventRange('program', min_value=0, max_value=127),
                    EventRange('drum', min_value=0, max_value=127),
                ]
            else:
                raise ValueError(f'Unknown codec name: {base_codec}')

            # Initialize codec
            self.codec = Codec(special_tokens=special_tokens + extra_tokens,
                               max_shift_steps=max_shift_steps,
                               event_ranges=event_ranges,
                               program_vocabulary=program_vocabulary,
                               drum_vocabulary=drum_vocabulary,
                               name='mt3')

        elif isinstance(base_codec, Codec):
            # If codec is a Codec object, store it directly.
            self.codec = base_codec
            if program_vocabulary is not None or drum_vocabulary is not None:
                warnings.warn("Vocabulary cannot be applied when using a custom codec.")
        else:
            # If codec is neither a string nor a Codec object, raise a TypeError.
            raise TypeError(f'Unknown codec type: {type(base_codec)}')
        self.num_tokens = self.codec._num_classes

    def _encode(self, events: List[Event]) -> List[int]:
        return [self.codec.encode_event(e) for e in events]

    def _decode(self, tokens: List[int]) -> List[Event]:
        return [self.codec.decode_event_index(idx) for idx in tokens]

    @abstractmethod
    def encode(self):
        """ Encode your custom events to tokens. """
        pass

    @abstractmethod
    def decode(self):
        """ Decode your custom tokens to events."""
        pass


class EventTokenizer(EventTokenizerBase):
    """
    Encodes and decodes events to and from tokens.
    """

    def __init__(self,
                 base_codec: Union[Codec, str] = 'mt3',
                 special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
                 extra_tokens: List[str] = [],
                 max_shift_steps: int = 206,
                 program_vocabulary: Optional[Dict] = None,
                 drum_vocabulary: Optional[Dict] = None) -> None:
        """
        Initializes the EventTokenizer object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        super().__init__(
            base_codec=base_codec,
            special_tokens=special_tokens,
            extra_tokens=extra_tokens,
            max_shift_steps=max_shift_steps,
            program_vocabulary=program_vocabulary,
            drum_vocabulary=drum_vocabulary,
        )

    def encode(self, events):
        """ Encode your custom events to tokens. """
        return super()._encode(events)

    def decode(self, tokens):
        """ Decode your custom tokens to events."""
        return super()._decode(tokens)
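
# A minimal usage sketch for EventTokenizer (illustrative only; it assumes the
# default 'mt3' codec and that Event is constructed as Event(type, value), as
# suggested by utils.note_event_dataclasses):
#
#   tokenizer = EventTokenizer('mt3')
#   tokens = tokenizer.encode([Event('pitch', 60), Event('velocity', 1)])
#   events = tokenizer.decode(tokens)  # round-trips back to the Event objects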


class NoteEventTokenizer(EventTokenizerBase):
    """ Encodes and decodes note events to/from tokens. """

    def __init__(
            self,
            base_codec: Union[Codec, str] = 'mt3',
            max_length: int = 1024,  # max length of tokens 
            tps: int = 100,
            sort_note_event: bool = True,
            special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
            extra_tokens: List[str] = [],
            max_shift_steps: int = 206,
            program_vocabulary: Optional[Dict] = None,
            drum_vocabulary: Optional[Dict] = None,
            ignore_decoding_tokens: List[str] = [],
            ignore_decoding_tokens_from_and_to: Optional[List[str]] = None,
            debug_mode: bool = False) -> None:
        """
        Initializes the NoteEventTokenizer object.

        List[NoteEvent] -> encode() -> List[int]
        List[int] -> decode() -> Tuple[List[NoteEvent], List[NoteEvent], ...]

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.

        :param ignore_decoding_tokens: List of tokens to ignore during decoding.
        :param ignore_decoding_tokens_from_and_to: List of tokens to ignore during decoding. [from, to]
        """
        super().__init__(base_codec=base_codec,
                         special_tokens=special_tokens,
                         extra_tokens=extra_tokens,
                         max_shift_steps=max_shift_steps,
                         program_vocabulary=program_vocabulary,
                         drum_vocabulary=drum_vocabulary)
        self.max_length = max_length
        self.tps = tps
        self.sort = sort_note_event

        # Prepare prefix, suffix and pad tokens. Pad buffers are resized lazily
        # when a different max_length is requested at encode time.
        self._prefix = []
        self._suffix = []
        self._zero_pad = []
        self._zero_pad_task = []  # used by encode_task()
        for stk in self.codec.special_tokens:
            if stk == 'EOS':
                self._suffix.append(self.codec.special_tokens.index('EOS'))
            elif stk == 'PAD':
                self._zero_pad = [0] * self.max_length  # 'PAD' is expected at index 0
            elif stk == 'UNK':
                pass
            else:
                pass
                # raise NotImplementedError(f'Unknown special token: {stk}')
        self.eos_id = self.codec.special_tokens.index('EOS')
        self.pad_id = self.codec.special_tokens.index('PAD')
        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
        self.debug_mode = debug_mode

    def _decode(self, tokens):
        # Plain event detokenizer (not a note_event decoder). It is used for
        # displaying events in the validation dashboard.
        return super()._decode(tokens)

    def encode(
        self,
        note_events: List[NoteEvent],
        tie_note_events: Optional[List[NoteEvent]] = None,
        start_time: float = 0.,
    ) -> List[int]:
        """ Encodes note events and tie note events to tokens. """
        events = note_event2event(
            note_events=note_events,
            tie_note_events=tie_note_events,
            start_time=start_time,  # required for calculating relative time
            tps=self.tps,
            sort=self.sort)
        return super()._encode(events)

    def encode_plus(
            self,
            note_events: List[NoteEvent],
            tie_note_events: Optional[List[NoteEvent]] = None,
            start_times: float = 0.,  # segment start time in seconds
            add_special_tokens: Optional[bool] = True,
            max_length: Optional[int] = None,  #  if None, use self.max_length
            pad_to_max_length: Optional[bool] = True,
            return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
        """ Encodes note events and tie note info to padded tokens. """
        encoded = self.encode(note_events, tie_note_events, start_times)

        if add_special_tokens:
            if self._prefix:
                encoded = self._prefix + encoded
            if self._suffix:
                encoded = encoded + self._suffix

        if max_length is None:
            max_length = self.max_length

        length = len(encoded)
        if length >= max_length:
            # Truncate to max_length. Note that this can cut off the EOS suffix.
            encoded = encoded[:max_length]
            length = max_length

        if return_attention_mask:
            attention_mask = [1] * length

        # <PAD>: pad the sequence (and mask) out to max_length.
        if pad_to_max_length is True:
            if len(self._zero_pad) != max_length:
                self._zero_pad = [self.pad_id] * max_length
            if return_attention_mask:
                attention_mask += [0] * (max_length - length)  # padded positions are masked out
            encoded = encoded + self._zero_pad[length:]

        if return_attention_mask:
            return encoded, attention_mask

        return encoded
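
    # A usage sketch for encode_plus() (illustrative; the NoteEvent field order
    # (is_drum, program, time, velocity, pitch) is an assumption based on
    # utils.note_event_dataclasses, with velocity 1/0 marking onset/offset):
    #
    #   tokenizer = NoteEventTokenizer(max_length=16)
    #   note_events = [NoteEvent(False, 0, 0.1, 1, 60),   # note-on
    #                  NoteEvent(False, 0, 0.5, 0, 60)]   # note-off
    #   tokens, mask = tokenizer.encode_plus(note_events, return_attention_mask=True)
    #   # len(tokens) == len(mask) == 16; trailing positions are PAD with mask 0.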

    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
        # NOTE: This encodes task events (plain Event objects) to token ids; it does not take note_event objects.
        encoded = super()._encode(task_events)

        # <PAD>
        if max_length is not None:
            if len(self._zero_pad_task) != max_length:
                self._zero_pad_task = [self.pad_id] * max_length
            length = len(encoded)
            encoded = encoded + self._zero_pad_task[length:]

        return encoded

    def decode(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False,
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], int],
               Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[Event], int]]:
        """Decodes a sequence of tokens into note events.

        Args:
            tokens (List[int]): The list of tokens to be decoded.
            start_time (float, optional): The starting time for the note events. Defaults to 0.
            return_events (bool, optional): Indicates whether to include the raw events in the return value.
                                            Defaults to False.

        Returns:
            Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], int],
                Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[Event], int]]: The decoded note events.
            If `return_events` is False, the returned tuple contains `note_events`, `tie_note_events`,
            `last_activity`, and `err_cnt`.
            If `return_events` is True, the returned tuple contains `note_events`, `tie_note_events`,
            `last_activity`, `events`, and `err_cnt`.
        """
        if self.debug_mode:
            ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding]
            print(f'Ignored tokens: {ignored_tokens_from_input}')

        if self.ids_to_ignore_decoding:
            tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding]

        events = super()._decode(tokens)
        note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps)
        if return_events:
            return note_events, tie_note_events, last_activity, events, err_cnt
        else:
            return note_events, tie_note_events, last_activity, err_cnt
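
    # Decoding sketch (illustrative): tokens produced by encode_plus() can be fed
    # back through decode() to recover note events, any tie (still-sounding) note
    # events, the last activity, and the error count:
    #
    #   note_events, tie_note_events, last_activity, err_cnt = tokenizer.decode(tokens)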

    def decode_batch(
        self,
        batch_tokens: Union[List[List[int]], np.ndarray],
        start_times: List[float],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], List[List[Event]],
                     int]]:
        """ 
        Decodes a batch of tokens to note_events and tie_note_events.

        Args:
            batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded.
            start_times (List[float]): List of start times for each token set.
            return_events (bool, optional): Flag to determine if events should be returned. Defaults to False.

        """
        if isinstance(batch_tokens, np.ndarray):
            batch_tokens = batch_tokens.tolist()

        if len(batch_tokens) != len(start_times):
            raise ValueError('The length of batch_tokens and start_times must be the same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(batch_tokens, start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)

            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt

    def decode_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]],
                     List[List[Event]], int]]:
        """ 
        Decodes a list of variable-size batches of token arrays into a single
        flattened list of zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where each array has shape (batch_size, variable_length).
            list_start_times: List[float], whose length is the sum of all batch sizes.
            return_events: bool, Defaults to False.

        Returns:
            zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int]]: Last activity of the segment, as [(program, pitch), ...]. This is useful
                            for validating notes within a batch of segments extracted from a file.
                        float: The segment start time.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                int: The total number of decoding errors.
        """
        # Flatten the batches into a single list of token sequences.
        list_tokens = []
        for arr in list_batch_tokens:
            for tokens in arr:
                list_tokens.append(tokens)
        if len(list_tokens) != len(list_start_times):
            raise ValueError('The length of list_tokens and list_start_times must be the same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0
        for tokens, start_time in zip(list_tokens, list_start_times):
            # decode() returns 5 values with return_events=True, and 4 otherwise.
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(tokens, start_time, True)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, False)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
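

# Batch decoding sketch (illustrative; the token ids and shapes below are made up):
#
#   batch = np.zeros((2, 1024), dtype=int)  # e.g. ids emitted by a seq2seq model
#   zipped, total_err_cnt = tokenizer.decode_batch(batch, start_times=[0.0, 2.048])
#   # decode_list_batches() takes a list of such batches, with start times
#   # flattened so that len(list_start_times) equals the total number of segments.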