File size: 8,104 Bytes
572abf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
Utilities for operating on encoded MIDI sequences.
"""

from collections import defaultdict

from anticipation.config import *
from anticipation.vocab import *


def print_tokens(tokens):
    """
    Print a human-readable dump of a token sequence, one triple per line.

    Inputs:
      tokens : flat sequence read as consecutive (time, duration, note) triples

    Each line starts with the triple's index. Separators and rests are
    labelled as such; every other triple is printed with its vocabulary
    offsets removed as (time, duration, instrument, pitch), where
    instrument = note // 2**7 and pitch is the remainder. Triples whose note
    is at or above CONTROL_OFFSET are decoded with the A* offsets and
    suffixed with '(A)'.
    """
    print('---------------------')
    triples = zip(tokens[0::3], tokens[1::3], tokens[2::3])
    for index, (tm, dur, note) in enumerate(triples):
        if note == SEPARATOR:
            # a separator occupies all three slots of its triple
            assert tm == SEPARATOR and dur == SEPARATOR
            print(index, 'SEPARATOR')
        elif note == REST:
            assert tm < CONTROL_OFFSET
            assert dur == DUR_OFFSET+0
            # the rest's time is printed as stored (offset not removed)
            print(index, tm, 'REST')
        elif note < CONTROL_OFFSET:
            code = note - NOTE_OFFSET
            instr = code//2**7
            print(index, tm - TIME_OFFSET, dur - DUR_OFFSET, instr, code - (2**7)*instr)
        else:
            code = note - ANOTE_OFFSET
            instr = code//2**7
            print(index, tm - ATIME_OFFSET, dur - ADUR_OFFSET, instr, code - (2**7)*instr, '(A)')


def clip(tokens, start, end, clip_duration=True, seconds=True):
    """
    Keep only the triples whose onset lies in the closed interval [start, end].

    Inputs:
      tokens        : flat sequence of (time, duration, note) triples
      start, end    : interval bounds; multiplied by TIME_RESOLUTION and
                      truncated to int when `seconds` is true, used as-is otherwise
      clip_duration : when true, shorten durations that run past `end`
      seconds       : whether start/end are given in seconds

    Returns:
      a new flat token list; onsets are left untouched (not shifted to `start`)
    """
    if seconds:
        start = int(TIME_RESOLUTION*start)
        end = int(TIME_RESOLUTION*end)

    kept = []
    for (time, dur, note) in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # pick the offsets matching the vocabulary this triple lives in
        is_event = note < CONTROL_OFFSET
        onset = time - (TIME_OFFSET if is_event else ATIME_OFFSET)
        length = dur - (DUR_OFFSET if is_event else ADUR_OFFSET)

        if not (start <= onset <= end):
            continue

        # truncate notes that extend beyond the end of the window
        overshoot = onset + length - end
        if clip_duration and overshoot > 0:
            dur -= overshoot

        kept.extend([time, dur, note])

    return kept


def mask(tokens, start, end):
    """
    Drop every triple whose onset falls strictly between start and end.

    Inputs:
      tokens     : flat sequence of (time, duration, note) triples
      start, end : bounds in seconds (onsets are divided by TIME_RESOLUTION
                   before comparison); triples exactly at a bound are kept

    Returns:
      a new flat token list without the masked triples
    """
    kept = []
    for (time, dur, note) in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        base = TIME_OFFSET if note < CONTROL_OFFSET else ATIME_OFFSET
        onset_seconds = (time - base)/float(TIME_RESOLUTION)

        if not (start < onset_seconds < end):
            kept.extend([time, dur, note])

    return kept


def delete(tokens, criterion):
    """
    Remove every triple for which `criterion` returns a truthy value.

    Inputs:
      tokens    : flat sequence of (time, duration, note) triples
      criterion : callable taking one (time, duration, note) tuple

    Returns:
      a new flat token list holding the surviving triples in original order
    """
    triples = zip(tokens[0::3], tokens[1::3], tokens[2::3])
    return [value
            for triple in triples
            if not criterion(triple)
            for value in triple]


def sort(tokens):
    """ order the triples of a sequence of events or controls (but not both) by time """

    onsets = tokens[0::3]

    # sorted() is stable, so triples sharing a time keep their relative order
    ranked = sorted(enumerate(onsets), key=lambda pair: pair[1])

    return [value
            for position, _ in ranked
            for value in tokens[3*position:3*position+3]]


def split(tokens):
    """ partition a sequence into (events, controls), preserving order within each """

    events, controls = [], []
    for triple in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # the note token decides which vocabulary the whole triple belongs to
        target = events if triple[2] < CONTROL_OFFSET else controls
        target.extend(triple)

    return events, controls


def pad(tokens, end_time=None, density=TIME_RESOLUTION):
    """
    Insert REST triples so that consecutive triples are at most `density` apart.

    Inputs:
      tokens   : flat sequence of event triples (no controls; asserted)
      end_time : time up to which to pad, without TIME_OFFSET applied; any
                 falsy value (None or 0) falls back to max_time(tokens, seconds=False)
      density  : largest allowed gap between consecutive triples

    Returns:
      a new flat token list with (time, DUR_OFFSET+0, REST) triples filled in
    """
    if not end_time:
        end_time = max_time(tokens, seconds=False)
    end_time = TIME_OFFSET+end_time

    padded = []
    cursor = TIME_OFFSET+0
    for (time, dur, note) in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # must pad before separation, anticipation
        assert note < CONTROL_OFFSET

        # step forward in `density` increments until the next event is in reach
        while cursor + density < time:
            cursor += density
            padded.extend([cursor, DUR_OFFSET+0, REST])

        padded.extend([time, dur, note])
        cursor = time

    # fill the tail between the last triple and end_time
    while cursor + density < end_time:
        cursor += density
        padded.extend([cursor, DUR_OFFSET+0, REST])

    return padded


def unpad(tokens):
    """ return a copy of the sequence with every REST triple removed """

    triples = zip(tokens[0::3], tokens[1::3], tokens[2::3])
    return [value
            for triple in triples
            if triple[2] != REST
            for value in triple]


def anticipate(events, controls, delta=DELTA*TIME_RESOLUTION):
    """
    Interleave a sequence of events with anticipated controls.

    A control is emitted as soon as the time of the most recently emitted
    event (0 before the first event) reaches the control's time minus delta,
    i.e. immediately before the next event is appended.

    Inputs:
      events   : a sequence of events
      controls : a sequence of time-localized controls
      delta    : the anticipation interval

    Returns:
      tokens   : interleaved events and anticipated controls
      controls : unconsumed controls (control time > max_time(events) + delta)
    """

    if len(controls) == 0:
        return events, controls

    tokens = []
    event_time = 0
    num_controls = len(controls)

    # Walk the controls with an index rather than re-slicing the list after
    # each consumed control; re-slicing copies the whole remainder every time.
    cursor = 0
    control_time = controls[0] - ATIME_OFFSET
    for time, dur, note in zip(events[0::3],events[1::3],events[2::3]):
        while event_time >= control_time - delta:
            tokens.extend(controls[cursor:cursor+3])
            cursor += 3 # consume this control
            # once the controls run out, +inf makes the loop condition false
            control_time = controls[cursor] - ATIME_OFFSET if cursor < num_controls else float('inf')

        assert note < CONTROL_OFFSET
        event_time = time - TIME_OFFSET
        tokens.extend([time, dur, note])

    return tokens, controls[cursor:]


def sparsity(tokens):
    """
    Return the largest gap between the times of consecutive triples.

    The first gap is measured from TIME_OFFSET+0. Separator triples are
    skipped, and the sequence must not contain controls (asserted).
    """
    widest_gap = 0
    last_time = TIME_OFFSET+0
    for (time, _, note) in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if note == SEPARATOR:
            continue

        # don't operate on interleaved sequences
        assert note < CONTROL_OFFSET

        gap = time - last_time
        if gap > widest_gap:
            widest_gap = gap
        last_time = time

    return widest_gap


def min_time(tokens, seconds=True, instr=None):
    """
    Return the earliest onset in the sequence, scanning up to the first separator.

    Inputs:
      tokens  : flat sequence of (time, duration, note) triples
      seconds : when true, divide the result by TIME_RESOLUTION (float result)
      instr   : if given, only consider triples whose instrument
                (offset-free note // 2**7) equals this value

    Returns:
      the smallest offset-free time, or 0 when no triple qualifies
    """
    earliest = None
    for time, _, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # stop calculating at sequence separator
        if note == SEPARATOR:
            break

        if note < CONTROL_OFFSET:
            onset, code = time - TIME_OFFSET, note - NOTE_OFFSET
        else:
            onset, code = time - ATIME_OFFSET, note - ANOTE_OFFSET

        # honour the optional instrument filter
        if instr is None or instr == code//2**7:
            if earliest is None or onset < earliest:
                earliest = onset

    if earliest is None:
        earliest = 0

    if seconds:
        return earliest/float(TIME_RESOLUTION)
    return earliest


def max_time(tokens, seconds=True, instr=None):
    """
    Return the latest onset in the sequence, looking past separators.

    Inputs:
      tokens  : flat sequence of (time, duration, note) triples
      seconds : when true, divide the result by TIME_RESOLUTION (float result)
      instr   : if given, only consider triples whose instrument
                (offset-free note // 2**7) equals this value

    Returns:
      the largest offset-free time, never less than 0
    """
    latest = 0
    for time, _, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # keep checking for max_time, even if it appears after a separator
        # (this is important because we use this check for vocab overflow in tokenization)
        if note == SEPARATOR:
            continue

        if note < CONTROL_OFFSET:
            onset, code = time - TIME_OFFSET, note - NOTE_OFFSET
        else:
            onset, code = time - ATIME_OFFSET, note - ANOTE_OFFSET

        # honour the optional instrument filter
        if instr is None or instr == code//2**7:
            if onset > latest:
                latest = onset

    if seconds:
        return latest/float(TIME_RESOLUTION)
    return latest


def get_instruments(tokens):
    """
    Count how many triples each instrument contributes.

    Triples whose note token is at or above SPECIAL_OFFSET are ignored. The
    instrument is the offset-free note value divided (floor) by 2**7.

    Returns:
      a defaultdict(int) mapping instrument -> number of triples
    """
    counts = defaultdict(int)

    # only the note column matters; it is the shortest of the three strided
    # slices, so this visits exactly the complete triples
    for note in tokens[2::3]:
        if note >= SPECIAL_OFFSET:
            continue

        base = NOTE_OFFSET if note < CONTROL_OFFSET else ANOTE_OFFSET
        counts[(note - base)//2**7] += 1

    return counts


def translate(tokens, dt, seconds=False):
    """
    Shift the time of every triple before the first separator by dt.

    Inputs:
      tokens  : flat sequence of (time, duration, note) triples
      dt      : shift to apply; multiplied by TIME_RESOLUTION and truncated
                to int when `seconds` is true
      seconds : whether dt is given in seconds

    Returns:
      a new flat token list; separators are copied verbatim and triples
      after the first separator are left unshifted

    A shift that would move an offset-free time below zero trips an assertion.
    """
    if seconds:
        dt = int(TIME_RESOLUTION*dt)

    shifted = []
    for (time, dur, note) in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if note == SEPARATOR:
            # stop translating after EOT: later triples get a zero shift
            shifted.extend([time, dur, note])
            dt = 0
            continue

        base = TIME_OFFSET if note < CONTROL_OFFSET else ATIME_OFFSET
        assert 0 <= (time - base) + dt
        shifted.extend([time+dt, dur, note])

    return shifted

def combine(events, controls):
    """
    Merge events and controls into one time-sorted sequence.

    CONTROL_OFFSET is subtracted from every control token (time, duration
    and note alike) before the merged list is passed to sort().
    """
    # NOTE(review): presumably this maps each control token onto its event
    # counterpart, which assumes every A* offset sits exactly CONTROL_OFFSET
    # above the matching event offset -- confirm against the vocab definitions
    lowered_controls = [token - CONTROL_OFFSET for token in controls]
    merged = events + lowered_controls
    return sort(merged)