File size: 17,423 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions that operate on NoteSequence protos."""

import dataclasses
import itertools

from typing import MutableMapping, MutableSet, Optional, Sequence, Tuple

from mt3 import event_codec
from mt3 import run_length_encoding
from mt3 import vocabularies

import note_seq

# Velocity used when none is supplied: for notes built from arrays without
# velocities, and for notes decoded by the onset-only decoder.
DEFAULT_VELOCITY = 100
# Duration (in the same time units as note start/end times) given to notes
# that have no explicit offset: missing offsets in array input, onset-only
# decoding, and drum notes.
DEFAULT_NOTE_DURATION = 0.01

# Quantization can result in zero-length notes; enforce a minimum duration.
# Decoded notes are extended so they last at least this long.
MIN_NOTE_DURATION = 0.01


@dataclasses.dataclass
class TrackSpec:
  """Describes one track by name plus the note attributes that identify it."""
  # Human-readable track name.
  name: str
  # Program number the track's notes carry (compared with note.program).
  program: int = dataclasses.field(default=0)
  # Whether the track consists of drum notes (compared with note.is_drum).
  is_drum: bool = dataclasses.field(default=False)


def extract_track(ns, program, is_drum):
  """Build a NoteSequence containing only the notes of a single instrument.

  Args:
    ns: NoteSequence to select notes from; it is only read, never modified.
    program: value `note.program` must equal for a note to be kept.
    is_drum: value `note.is_drum` must equal for a note to be kept.

  Returns:
    A new NoteSequence (ticks_per_quarter=220) holding the matching notes in
    their original order, with `total_time` set to the latest matching
    `end_time`, or 0.0 when nothing matches.
  """
  wanted = (program, is_drum)
  selected = [n for n in ns.notes if (n.program, n.is_drum) == wanted]
  track = note_seq.NoteSequence(ticks_per_quarter=220)
  track.notes.extend(selected)
  track.total_time = max((n.end_time for n in selected), default=0.0)
  return track


def trim_overlapping_notes(ns: note_seq.NoteSequence) -> note_seq.NoteSequence:
  """Trim overlapping notes from a NoteSequence, dropping zero-length notes.

  Notes are grouped into channels keyed by (pitch, program, is_drum). Within
  a channel, notes are ordered by start time (ties keep their original
  order), and any note that extends past the start of the next note is cut
  off at that start. Notes whose resulting duration is not positive are
  removed.

  Args:
    ns: Input NoteSequence; it is copied, not modified.

  Returns:
    A trimmed copy of `ns`.
  """
  ns_trimmed = note_seq.NoteSequence()
  ns_trimmed.CopyFrom(ns)

  # Group notes by channel in one pass. The previous version rescanned every
  # note once per channel, which is O(channels * notes).
  notes_by_channel = {}
  for note in ns_trimmed.notes:
    channel = (note.pitch, note.program, note.is_drum)
    notes_by_channel.setdefault(channel, []).append(note)

  for channel_notes in notes_by_channel.values():
    # list.sort is stable, so equal start times keep their sequence order.
    channel_notes.sort(key=lambda note: note.start_time)
    for prev_note, next_note in zip(channel_notes, channel_notes[1:]):
      if prev_note.end_time > next_note.start_time:
        prev_note.end_time = next_note.start_time

  valid_notes = [note for note in ns_trimmed.notes
                 if note.start_time < note.end_time]
  del ns_trimmed.notes[:]
  ns_trimmed.notes.extend(valid_notes)
  return ns_trimmed


def assign_instruments(ns: note_seq.NoteSequence) -> None:
  """Set `instrument` on every note of `ns`, mutating it in place.

  Drum notes always receive instrument 9. Each distinct non-drum program is
  numbered in order of first appearance (0, 1, 2, ...), skipping 9 so that
  pitched instruments never share a number with drums.
  """
  instrument_for_program = {}
  for note in ns.notes:
    if note.is_drum:
      note.instrument = 9
      continue
    if note.program in instrument_for_program:
      note.instrument = instrument_for_program[note.program]
      continue
    next_index = len(instrument_for_program)
    if next_index >= 9:
      next_index += 1  # 9 is reserved for drums
    note.instrument = next_index
    instrument_for_program[note.program] = next_index


def validate_note_sequence(ns: note_seq.NoteSequence) -> None:
  """Check every note of `ns`, raising on the first invalid one.

  Raises:
    ValueError: if a note does not start strictly before it ends, or if a
      note has zero velocity. The timing check runs first for each note.
  """
  for note in ns.notes:
    start, end = note.start_time, note.end_time
    if start >= end:
      message = 'note has start time >= end time: %f >= %f' % (start, end)
      raise ValueError(message)
    if note.velocity == 0:
      raise ValueError('note has zero velocity')


def note_arrays_to_note_sequence(
    onset_times: Sequence[float],
    pitches: Sequence[int],
    offset_times: Optional[Sequence[float]] = None,
    velocities: Optional[Sequence[int]] = None,
    programs: Optional[Sequence[int]] = None,
    is_drums: Optional[Sequence[bool]] = None
) -> note_seq.NoteSequence:
  """Convert note onset / offset / pitch / velocity arrays to NoteSequence.

  `onset_times` and `pitches` are parallel arrays with one entry per note.
  Each optional array, when given, supplies that attribute per note. If an
  optional array is omitted or shorter than `onset_times`, missing entries
  fall back to defaults: offset = onset + DEFAULT_NOTE_DURATION, velocity =
  DEFAULT_VELOCITY, program = 0, is_drum = False.

  Args:
    onset_times: note start times.
    pitches: note pitches, same length as `onset_times`.
    offset_times: optional note end times.
    velocities: optional note velocities.
    programs: optional note programs.
    is_drums: optional per-note drum flags.

  Returns:
    A NoteSequence (ticks_per_quarter=220) with one note per onset,
    `total_time` set to the latest offset, and instrument numbers assigned by
    `assign_instruments`.

  Raises:
    ValueError: if `pitches` and `onset_times` differ in length, or if an
      optional array has more entries than `onset_times`. Such input used to
      silently produce notes with a missing onset time or pitch.
  """
  num_notes = len(onset_times)
  if len(pitches) != num_notes:
    raise ValueError('onset_times and pitches differ in length: %d != %d' %
                     (num_notes, len(pitches)))
  optional_arrays = (('offset_times', offset_times),
                     ('velocities', velocities),
                     ('programs', programs),
                     ('is_drums', is_drums))
  for array_name, array in optional_arrays:
    if array is not None and len(array) > num_notes:
      raise ValueError('%s has more entries than onset_times: %d > %d' %
                       (array_name, len(array), num_notes))

  ns = note_seq.NoteSequence(ticks_per_quarter=220)
  # zip_longest pads omitted/shorter optional arrays with None; each None is
  # replaced by the corresponding default below.
  note_fields = itertools.zip_longest(
      onset_times, pitches,
      [] if offset_times is None else offset_times,
      [] if velocities is None else velocities,
      [] if programs is None else programs,
      [] if is_drums is None else is_drums)
  for onset_time, pitch, offset_time, velocity, program, is_drum in note_fields:
    if offset_time is None:
      offset_time = onset_time + DEFAULT_NOTE_DURATION
    if velocity is None:
      velocity = DEFAULT_VELOCITY
    if program is None:
      program = 0
    if is_drum is None:
      is_drum = False
    ns.notes.add(
        start_time=onset_time,
        end_time=offset_time,
        pitch=pitch,
        velocity=velocity,
        program=program,
        is_drum=is_drum)
    ns.total_time = max(ns.total_time, offset_time)
  assign_instruments(ns)
  return ns


@dataclasses.dataclass
class NoteEventData:
  """Payload associated with one onset or offset time.

  Only `pitch` is required. The other fields stay None unless the caller
  supplies them; onset-only extraction, for example, sets just the pitch.
  """
  pitch: int
  # Note velocity; the offset extractors in this module use 0 for offsets.
  velocity: Optional[int] = dataclasses.field(default=None)
  # Program number of the note, when programs are modeled.
  program: Optional[int] = dataclasses.field(default=None)
  # Drum flag of the note, when programs are modeled.
  is_drum: Optional[bool] = dataclasses.field(default=None)
  # Optional instrument number.
  instrument: Optional[int] = dataclasses.field(default=None)


def note_sequence_to_onsets(
    ns: note_seq.NoteSequence
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Return onset times and pitch-only event data for each note in `ns`.

  Notes are ordered by pitch (equal pitches keep their original order), so
  that a subsequent stable sort by time breaks ties by pitch.
  """
  onsets = []
  events = []
  for n in sorted(ns.notes, key=lambda n: n.pitch):
    onsets.append(n.start_time)
    events.append(NoteEventData(pitch=n.pitch))
  return onsets, events


def note_sequence_to_onsets_and_offsets(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches from a NoteSequence proto.

  The returned times are not sorted. All offsets come first, then all onsets.
  Within each group, notes are ordered by pitch. A subsequent stable sort by
  time therefore places an offset ahead of an onset at the same time and
  breaks remaining ties by pitch.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: note offset times followed by note onset times.
    values: matching NoteEventData objects; offsets carry velocity 0, onsets
        carry the note's velocity.
  """
  by_pitch = sorted(ns.notes, key=lambda n: n.pitch)

  offset_times = [n.end_time for n in by_pitch]
  offset_values = [NoteEventData(pitch=n.pitch, velocity=0) for n in by_pitch]

  onset_times = [n.start_time for n in by_pitch]
  onset_values = [NoteEventData(pitch=n.pitch, velocity=n.velocity)
                  for n in by_pitch]

  return offset_times + onset_times, offset_values + onset_values


def note_sequence_to_onsets_and_offsets_and_programs(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches & programs from a NoteSequence.

  The returned times are not sorted. Offsets of non-drum notes come first,
  then onsets of all notes (drums have no offsets). Within each group, notes
  are ordered by (is_drum, program, pitch). A subsequent stable sort by time
  therefore places offsets before onsets and breaks remaining ties by
  program and pitch.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: non-drum offset times followed by all onset times.
    values: matching NoteEventData objects; offsets carry velocity 0 and
        is_drum=False, onsets carry the note's velocity and drum flag.
  """
  ordered = sorted(ns.notes,
                   key=lambda n: (n.is_drum, n.program, n.pitch))
  pitched = [n for n in ordered if not n.is_drum]

  offset_times = [n.end_time for n in pitched]
  offset_values = [
      NoteEventData(pitch=n.pitch, velocity=0, program=n.program,
                    is_drum=False)
      for n in pitched]

  onset_times = [n.start_time for n in ordered]
  onset_values = [
      NoteEventData(pitch=n.pitch, velocity=n.velocity, program=n.program,
                    is_drum=n.is_drum)
      for n in ordered]

  return offset_times + onset_times, offset_values + onset_values


@dataclasses.dataclass
class NoteEncodingState:
  """Mutable encoder state recording the latest velocity bin per note."""
  # (pitch, program) -> velocity bin of the most recent event for that key.
  # Each instance gets its own fresh dict.
  active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(
      default_factory=lambda: {})


def note_event_data_to_events(
    state: Optional[NoteEncodingState],
    value: NoteEventData,
    codec: event_codec.Codec,
) -> Sequence[event_codec.Event]:
  """Translate one NoteEventData into the events that encode it.

  The event layout depends on which fields of `value` are set:
    * no velocity: a lone pitch event (onset-only vocabulary);
    * velocity but no program: velocity + pitch;
    * program and is_drum: velocity + drum;
    * program, not drum: program + velocity + pitch.
  When `state` is given, non-drum events record their velocity bin in
  `state.active_pitches` under (pitch, program), using program 0 when no
  program is modeled.
  """
  if value.velocity is None:
    return [event_codec.Event('pitch', value.pitch)]

  bins = vocabularies.num_velocity_bins_from_codec(codec)
  velocity_bin = vocabularies.velocity_to_bin(value.velocity, bins)
  velocity_event = event_codec.Event('velocity', velocity_bin)

  if value.program is None:
    if state is not None:
      state.active_pitches[(value.pitch, 0)] = velocity_bin
    return [velocity_event, event_codec.Event('pitch', value.pitch)]

  if value.is_drum:
    # Drums live in their own vocabulary and are never tracked as active.
    return [velocity_event, event_codec.Event('drum', value.pitch)]

  if state is not None:
    state.active_pitches[(value.pitch, value.program)] = velocity_bin
  return [event_codec.Event('program', value.program),
          velocity_event,
          event_codec.Event('pitch', value.pitch)]


def note_encoding_state_to_events(
    state: NoteEncodingState
) -> Sequence[event_codec.Event]:
  """Emit program/pitch pairs for still-active notes, then a tie event.

  Entries whose stored velocity bin is falsy are skipped. The remaining ones
  are emitted in order of (program, pitch).
  """
  by_program_then_pitch = sorted(
      state.active_pitches,
      key=lambda pitch_program: (pitch_program[1], pitch_program[0]))
  events = []
  for pitch, program in by_program_then_pitch:
    if not state.active_pitches[(pitch, program)]:
      continue
    events.extend([event_codec.Event('program', program),
                   event_codec.Event('pitch', pitch)])
  return events + [event_codec.Event('tie', 0)]


@dataclasses.dataclass
class NoteDecodingState:
  """Mutable state carried across events while decoding notes."""
  # Time of the most recently processed event.
  current_time: float = dataclasses.field(default=0.0)
  # Velocity applied to upcoming pitch events; 0 makes them note-offs.
  current_velocity: int = dataclasses.field(default=DEFAULT_VELOCITY)
  # Program applied to upcoming pitch events.
  current_program: int = dataclasses.field(default=0)
  # (pitch, program) -> (onset time, onset velocity) for sounding notes.
  active_pitches: MutableMapping[Tuple[int, int], Tuple[float, int]] = (
      dataclasses.field(default_factory=dict))
  # (pitch, program) pairs declared as continuing from the previous segment.
  tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field(
      default_factory=set)
  # True while inside the tie section at the start of a segment.
  is_tie_section: bool = dataclasses.field(default=False)
  # NoteSequence accumulated so far.
  note_sequence: note_seq.NoteSequence = dataclasses.field(
      default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220))


def decode_note_onset_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec,
) -> None:
  """Decode one event of the onset-only vocabulary into `state`.

  A pitch event appends a note of DEFAULT_NOTE_DURATION and DEFAULT_VELOCITY
  starting at `time`, and extends `total_time` if needed. `codec` is unused.

  Raises:
    ValueError: for any event type other than 'pitch'.
  """
  if event.type != 'pitch':
    raise ValueError('unexpected event type: %s' % event.type)
  end_time = time + DEFAULT_NOTE_DURATION
  sequence = state.note_sequence
  sequence.notes.add(
      start_time=time, end_time=end_time,
      pitch=event.value, velocity=DEFAULT_VELOCITY)
  sequence.total_time = max(sequence.total_time, end_time)


def _add_note_to_sequence(
    ns: note_seq.NoteSequence,
    start_time: float, end_time: float, pitch: int, velocity: int,
    program: int = 0, is_drum: bool = False
) -> None:
  """Append a note to `ns`, lengthening it to at least MIN_NOTE_DURATION.

  `ns.total_time` is raised to the (possibly lengthened) end time if needed.
  """
  earliest_allowed_end = start_time + MIN_NOTE_DURATION
  clamped_end = max(end_time, earliest_allowed_end)
  ns.notes.add(
      pitch=pitch, velocity=velocity,
      start_time=start_time, end_time=clamped_end,
      program=program, is_drum=is_drum)
  ns.total_time = max(ns.total_time, clamped_end)


def decode_note_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec
) -> None:
  """Process note event and update decoding state.

  Handles the onset/offset (and optional program/drum/tie) vocabulary:
    * 'pitch': inside a tie section, marks (pitch, current_program) as tied;
      otherwise a note-off when current_velocity is 0, else a note onset.
    * 'drum': adds a fixed-length drum note at `time`.
    * 'velocity': sets current_velocity from the velocity bin.
    * 'program': sets current_program.
    * 'tie': ends the tie section, closing active notes not declared tied.

  Args:
    state: decoding state; mutated in place.
    time: event time; must not precede state.current_time.
    event: event to decode.
    codec: used only to look up the number of velocity bins.

  Raises:
    ValueError: on a time that goes backwards, an inconsistent pitch/program
      in the tie section, a note-off for an inactive note, a zero-velocity
      drum event, a tie event outside a tie section, or an unknown event type.
  """
  if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))
  state.current_time = time
  if event.type == 'pitch':
    pitch = event.value
    if state.is_tie_section:
      # "tied" pitch
      # Only notes already active from the previous segment may be tied, and
      # each at most once per tie section.
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('inactive pitch/program in tie section: %d/%d' %
                         (pitch, state.current_program))
      if (pitch, state.current_program) in state.tied_pitches:
        raise ValueError('pitch/program is already tied: %d/%d' %
                         (pitch, state.current_program))
      state.tied_pitches.add((pitch, state.current_program))
    elif state.current_velocity == 0:
      # note offset
      # Close the active note; _add_note_to_sequence enforces the minimum
      # note duration.
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('note-off for inactive pitch/program: %d/%d' %
                         (pitch, state.current_program))
      onset_time, onset_velocity = state.active_pitches.pop(
          (pitch, state.current_program))
      _add_note_to_sequence(
          state.note_sequence, start_time=onset_time, end_time=time,
          pitch=pitch, velocity=onset_velocity, program=state.current_program)
    else:
      # note onset
      if (pitch, state.current_program) in state.active_pitches:
        # The pitch is already active; this shouldn't really happen but we'll
        # try to handle it gracefully by ending the previous note and starting a
        # new one.
        onset_time, onset_velocity = state.active_pitches.pop(
            (pitch, state.current_program))
        _add_note_to_sequence(
            state.note_sequence, start_time=onset_time, end_time=time,
            pitch=pitch, velocity=onset_velocity, program=state.current_program)
      # Remember onset time and velocity until the matching note-off.
      state.active_pitches[(pitch, state.current_program)] = (
          time, state.current_velocity)
  elif event.type == 'drum':
    # drum onset (drums have no offset)
    # The note gets a fixed DEFAULT_NOTE_DURATION. No program is passed, so
    # the helper's default program (0) is recorded, not current_program.
    if state.current_velocity == 0:
      raise ValueError('velocity cannot be zero for drum event')
    offset_time = time + DEFAULT_NOTE_DURATION
    _add_note_to_sequence(
        state.note_sequence, start_time=time, end_time=offset_time,
        pitch=event.value, velocity=state.current_velocity, is_drum=True)
  elif event.type == 'velocity':
    # velocity change
    # The event value is a velocity bin; convert it back to a MIDI-style
    # velocity using the bin count implied by the codec.
    num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
    velocity = vocabularies.bin_to_velocity(event.value, num_velocity_bins)
    state.current_velocity = velocity
  elif event.type == 'program':
    # program change
    state.current_program = event.value
  elif event.type == 'tie':
    # end of tie section; end active notes that weren't declared tied
    # Untied notes end at the current time (the time of this tie event).
    if not state.is_tie_section:
      raise ValueError('tie section end event when not in tie section')
    for (pitch, program) in list(state.active_pitches.keys()):
      if (pitch, program) not in state.tied_pitches:
        onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
        _add_note_to_sequence(
            state.note_sequence,
            start_time=onset_time, end_time=state.current_time,
            pitch=pitch, velocity=onset_velocity, program=program)
    state.is_tie_section = False
  else:
    raise ValueError('unexpected event type: %s' % event.type)


def begin_tied_pitches_section(state: NoteDecodingState) -> None:
  """Enter the tie section of a new segment with no pitches tied yet."""
  state.is_tie_section = True
  state.tied_pitches = set()


def flush_note_decoding_state(
    state: NoteDecodingState
) -> note_seq.NoteSequence:
  """Close every still-active note and return the finished NoteSequence.

  All remaining notes end at one shared time: the later of the current time
  and (latest onset + MIN_NOTE_DURATION). `state.current_time` is advanced to
  that time. Instrument numbers are assigned before returning.
  """
  end_time = state.current_time
  for onset_time, _ in state.active_pitches.values():
    end_time = max(end_time, onset_time + MIN_NOTE_DURATION)
  state.current_time = end_time

  for key in list(state.active_pitches):
    onset_time, onset_velocity = state.active_pitches.pop(key)
    pitch, program = key
    _add_note_to_sequence(
        state.note_sequence, start_time=onset_time, end_time=end_time,
        pitch=pitch, velocity=onset_velocity, program=program)

  assign_instruments(state.note_sequence)
  return state.note_sequence


class NoteEncodingSpecType(run_length_encoding.EventEncodingSpec):
  """Note-specific subclass of `run_length_encoding.EventEncodingSpec`.

  It adds no fields or behavior of its own; the note encoding specs in this
  module are constructed from it.
  """
  pass


# Spec for onset-only modeling: no encoder state, no per-segment setup, and
# flushing simply returns the accumulated NoteSequence.
NoteOnsetEncodingSpec = NoteEncodingSpecType(
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_onset_event,
    flush_decoding_state_fn=lambda state: state.note_sequence,
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None)


# Spec for onset+offset modeling without ties: notes still sounding when
# decoding ends are closed by flush_note_decoding_state.
NoteEncodingSpec = NoteEncodingSpecType(
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state,
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None)


# Spec for onset+offset modeling with ties: the encoder tracks active notes
# so each segment can open with a tie section listing them, and the decoder
# starts every segment inside that tie section.
NoteEncodingWithTiesSpec = NoteEncodingSpecType(
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=begin_tied_pitches_section,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state,
    init_encoding_state_fn=NoteEncodingState,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=note_encoding_state_to_events)