Spaces:
Build error
Build error
# Copyright 2022 The MT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Helper functions that operate on NoteSequence protos."""
import dataclasses | |
import itertools | |
from typing import MutableMapping, MutableSet, Optional, Sequence, Tuple | |
from mt3 import event_codec | |
from mt3 import run_length_encoding | |
from mt3 import vocabularies | |
import note_seq | |
# Velocity given to notes when none is supplied (array conversion) or when
# decoding onset-only events.
DEFAULT_VELOCITY = 100

# Duration given to notes that have no explicit offset time (missing offsets in
# array conversion, onset-only decoding, and drum hits).
DEFAULT_NOTE_DURATION = 0.01

# Quantization can result in zero-length notes; enforce a minimum duration.
MIN_NOTE_DURATION = 0.01
@dataclasses.dataclass
class TrackSpec:
  """Description of a single track.

  The class only declares annotated fields, so it needs the dataclass decorator
  to get a constructor, equality and a readable repr.

  Attributes:
    name: Name of the track.
    program: Program number of the track; defaults to 0.
    is_drum: Whether the track is a drum track; defaults to False.
  """
  name: str
  program: int = 0
  is_drum: bool = False
def extract_track(ns, program, is_drum):
  """Build a new NoteSequence containing the notes of a single track.

  Args:
    ns: NoteSequence to read notes from.
    program: Program number a note must have to be selected.
    is_drum: Drum flag a note must have to be selected.

  Returns:
    A new NoteSequence (220 ticks per quarter) holding the selected notes, with
    total_time set to the latest note end time, or 0.0 when no note matched.
  """
  def belongs_to_track(note):
    return note.program == program and note.is_drum == is_drum

  track = note_seq.NoteSequence(ticks_per_quarter=220)
  track.notes.extend(list(filter(belongs_to_track, ns.notes)))
  track.total_time = max(
      (note.end_time for note in track.notes), default=0.0)
  return track
def trim_overlapping_notes(ns: note_seq.NoteSequence) -> note_seq.NoteSequence:
  """Trim overlapping notes from a NoteSequence, dropping zero-length notes.

  Notes sharing the same (pitch, program, is_drum) are ordered by start time;
  whenever a note is still sounding when the next one starts, its end time is
  pulled back to that next start time. Notes whose start time is not strictly
  before their end time afterwards are removed.

  Args:
    ns: NoteSequence to trim; it is left unmodified.

  Returns:
    A trimmed copy of `ns`.
  """
  ns_trimmed = note_seq.NoteSequence()
  ns_trimmed.CopyFrom(ns)

  # Group the notes in a single pass instead of rescanning the whole sequence
  # once per channel. Appending preserves the original relative order, so the
  # stable sort below breaks start-time ties exactly as before.
  notes_by_channel = {}
  for note in ns_trimmed.notes:
    channel = (note.pitch, note.program, note.is_drum)
    notes_by_channel.setdefault(channel, []).append(note)

  for channel_notes in notes_by_channel.values():
    channel_notes.sort(key=lambda note: note.start_time)
    # Walk consecutive pairs; each comparison sees the (possibly already
    # trimmed) earlier note, but only its end time is ever changed.
    for earlier, later in zip(channel_notes, channel_notes[1:]):
      if earlier.end_time > later.start_time:
        earlier.end_time = later.start_time

  # Trimming may have produced zero-length notes; keep only the valid ones.
  valid_notes = [note for note in ns_trimmed.notes
                 if note.start_time < note.end_time]
  del ns_trimmed.notes[:]
  ns_trimmed.notes.extend(valid_notes)
  return ns_trimmed
def assign_instruments(ns: note_seq.NoteSequence) -> None:
  """Assign instrument numbers to notes; modifies NoteSequence in place.

  Every drum note gets instrument 9. Each distinct non-drum program gets its
  own instrument number in order of first appearance, counting up from 0 and
  skipping 9.

  Args:
    ns: NoteSequence whose notes receive an `instrument` value.
  """
  instrument_by_program = {}
  for note in ns.notes:
    if note.is_drum:
      note.instrument = 9
      continue
    if note.program not in instrument_by_program:
      next_instrument = len(instrument_by_program)
      if next_instrument >= 9:
        # Step over the number reserved for drums.
        next_instrument += 1
      instrument_by_program[note.program] = next_instrument
    note.instrument = instrument_by_program[note.program]
def validate_note_sequence(ns: note_seq.NoteSequence) -> None:
  """Raise ValueError if NoteSequence contains invalid notes.

  Args:
    ns: NoteSequence to check.

  Raises:
    ValueError: On the first note whose start time is not strictly before its
      end time, or whose velocity is zero.
  """
  for note in ns.notes:
    start, end = note.start_time, note.end_time
    if start >= end:
      raise ValueError(
          'note has start time >= end time: %f >= %f' % (start, end))
    if note.velocity == 0:
      raise ValueError('note has zero velocity')
def note_arrays_to_note_sequence(
    onset_times: Sequence[float],
    pitches: Sequence[int],
    offset_times: Optional[Sequence[float]] = None,
    velocities: Optional[Sequence[int]] = None,
    programs: Optional[Sequence[int]] = None,
    is_drums: Optional[Sequence[bool]] = None
) -> note_seq.NoteSequence:
  """Convert note onset / offset / pitch / velocity arrays to NoteSequence.

  The arrays are read in parallel, one note per position. Any optional array
  that is omitted (or runs out early) yields None for that field, which is then
  replaced by a default.

  Args:
    onset_times: Start time of each note.
    pitches: Pitch of each note.
    offset_times: End time of each note; a missing value becomes the onset time
      plus DEFAULT_NOTE_DURATION.
    velocities: Velocity of each note; a missing value becomes
      DEFAULT_VELOCITY.
    programs: Program of each note; a missing value becomes 0.
    is_drums: Drum flag of each note; a missing value becomes False.

  Returns:
    A NoteSequence (220 ticks per quarter) containing the notes, with
    total_time covering the latest end time and instruments assigned.
  """
  # zip_longest pads the shorter columns with None, which marks "use default".
  columns = (
      onset_times,
      offset_times if offset_times is not None else [],
      pitches,
      velocities if velocities is not None else [],
      programs if programs is not None else [],
      is_drums if is_drums is not None else [],
  )
  ns = note_seq.NoteSequence(ticks_per_quarter=220)
  for onset_time, offset_time, pitch, velocity, program, is_drum in (
      itertools.zip_longest(*columns)):
    end_time = (onset_time + DEFAULT_NOTE_DURATION
                if offset_time is None else offset_time)
    ns.notes.add(
        start_time=onset_time,
        end_time=end_time,
        pitch=pitch,
        velocity=DEFAULT_VELOCITY if velocity is None else velocity,
        program=0 if program is None else program,
        is_drum=False if is_drum is None else is_drum)
    ns.total_time = max(ns.total_time, end_time)
  assign_instruments(ns)
  return ns
@dataclasses.dataclass
class NoteEventData:
  """Data describing a single note event.

  Instances are built with keyword arguments throughout this module, which
  requires the dataclass-generated constructor.

  Attributes:
    pitch: Pitch of the note.
    velocity: Velocity of the note; None when only onsets are modeled, zero for
      note offsets.
    program: Program of the note, or None when programs are not modeled.
    is_drum: Whether the note is a drum hit, or None when not modeled.
    instrument: Instrument number of the note, if any.
  """
  pitch: int
  velocity: Optional[int] = None
  program: Optional[int] = None
  is_drum: Optional[bool] = None
  instrument: Optional[int] = None
def note_sequence_to_onsets(
    ns: note_seq.NoteSequence
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract note onsets and pitches from NoteSequence proto.

  Args:
    ns: NoteSequence from which to extract onsets.

  Returns:
    times: A list of note onset times, ordered by note pitch.
    values: A list of NoteEventData objects carrying only the pitch, in the same
      order as `times`.
  """
  # Sort by pitch to use as a tiebreaker for subsequent stable sort.
  times = []
  values = []
  for note in sorted(ns.notes, key=lambda note: note.pitch):
    times.append(note.start_time)
    values.append(NoteEventData(pitch=note.pitch))
  return (times, values)
def note_sequence_to_onsets_and_offsets(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches from a NoteSequence proto.

  The onset & offset times will not necessarily be in sorted order.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: A list of note onset and offset times.
    values: A list of NoteEventData objects where velocity is zero for note
      offsets.
  """
  # Sort by pitch and put offsets before onsets as a tiebreaker for subsequent
  # stable sort.
  by_pitch = sorted(ns.notes, key=lambda note: note.pitch)

  # First half: one offset entry per note.
  times = [note.end_time for note in by_pitch]
  values = [NoteEventData(pitch=note.pitch, velocity=0) for note in by_pitch]

  # Second half: one onset entry per note, in the same pitch order.
  times.extend(note.start_time for note in by_pitch)
  values.extend(
      NoteEventData(pitch=note.pitch, velocity=note.velocity)
      for note in by_pitch)
  return times, values
def note_sequence_to_onsets_and_offsets_and_programs(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches & programs from a NoteSequence.

  The onset & offset times will not necessarily be in sorted order. Drum notes
  contribute an onset only; no offset entry is produced for them.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: A list of note onset and offset times.
    values: A list of NoteEventData objects where velocity is zero for note
      offsets.
  """
  # Sort by program and pitch and put offsets before onsets as a tiebreaker for
  # subsequent stable sort.
  ordered = sorted(
      ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch))
  pitched = [note for note in ordered if not note.is_drum]

  # Offsets (non-drum notes only) come first...
  times = [note.end_time for note in pitched]
  values = [
      NoteEventData(pitch=note.pitch, velocity=0,
                    program=note.program, is_drum=False)
      for note in pitched
  ]

  # ...followed by the onsets of every note, drums included.
  times.extend(note.start_time for note in ordered)
  values.extend(
      NoteEventData(pitch=note.pitch, velocity=note.velocity,
                    program=note.program, is_drum=note.is_drum)
      for note in ordered)
  return times, values
@dataclasses.dataclass
class NoteEncodingState:
  """Encoding state for note transcription, keeping track of active pitches.

  The dataclass decorator is required so that `dataclasses.field` takes effect
  and every instance gets its own `active_pitches` dict.
  """
  # velocity bin for active pitches and programs, keyed by (pitch, program)
  active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(
      default_factory=dict)
def note_event_data_to_events(
    state: Optional[NoteEncodingState],
    value: NoteEventData,
    codec: event_codec.Codec,
) -> Sequence[event_codec.Event]:
  """Convert note event data to a sequence of events.

  Args:
    state: Optional encoding state. When given, the velocity bin of each
      non-drum event that carries a velocity is recorded under its
      (pitch, program) key, using program 0 when the value has no program.
    value: The note event to convert.
    codec: Codec used to determine the number of velocity bins.

  Returns:
    The events for `value`:
      * no velocity: a single 'pitch' event;
      * velocity but no program: 'velocity' then 'pitch';
      * drum with program: 'velocity' then 'drum';
      * otherwise: 'program', 'velocity', then 'pitch'.
  """
  if value.velocity is None:
    # onsets only, no program or velocity
    return [event_codec.Event('pitch', value.pitch)]

  num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
  velocity_bin = vocabularies.velocity_to_bin(
      value.velocity, num_velocity_bins)
  velocity_event = event_codec.Event('velocity', velocity_bin)

  if value.program is None:
    # onsets + offsets + velocities only, no programs
    if state is not None:
      state.active_pitches[(value.pitch, 0)] = velocity_bin
    return [velocity_event, event_codec.Event('pitch', value.pitch)]

  if value.is_drum:
    # drum events use a separate vocabulary
    return [velocity_event, event_codec.Event('drum', value.pitch)]

  # program + velocity + pitch
  if state is not None:
    state.active_pitches[(value.pitch, value.program)] = velocity_bin
  return [event_codec.Event('program', value.program),
          velocity_event,
          event_codec.Event('pitch', value.pitch)]
def note_encoding_state_to_events(
    state: NoteEncodingState
) -> Sequence[event_codec.Event]:
  """Output program and pitch events for active notes plus a final tie event.

  Args:
    state: Encoding state whose `active_pitches` maps (pitch, program) to a
      velocity bin.

  Returns:
    A 'program' and 'pitch' event pair for every entry with a nonzero velocity
    bin, ordered by program and then pitch, followed by one 'tie' event.
  """
  # Keys are unique, so ordering the items by (program, pitch) never has to
  # compare the velocity bins.
  ordered_items = sorted(
      state.active_pitches.items(),
      key=lambda item: (item[0][1], item[0][0]))

  events = []
  for (pitch, program), velocity_bin in ordered_items:
    if not velocity_bin:
      # A zero bin means the note is no longer sounding.
      continue
    events.extend([event_codec.Event('program', program),
                   event_codec.Event('pitch', pitch)])
  events.append(event_codec.Event('tie', 0))
  return events
@dataclasses.dataclass
class NoteDecodingState:
  """Decoding state for note transcription.

  The dataclass decorator is required so that the `dataclasses.field` defaults
  below give every instance its own containers and NoteSequence, and so that
  the class can be used as a zero-argument state factory.
  """
  current_time: float = 0.0
  # velocity to apply to subsequent pitch events (zero for note-off)
  current_velocity: int = DEFAULT_VELOCITY
  # program to apply to subsequent pitch events
  current_program: int = 0
  # onset time and velocity for active pitches and programs
  active_pitches: MutableMapping[Tuple[int, int],
                                 Tuple[float, int]] = dataclasses.field(
                                     default_factory=dict)
  # pitches (with programs) to continue from previous segment
  tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field(
      default_factory=set)
  # whether or not we are in the tie section at the beginning of a segment
  is_tie_section: bool = False
  # partially-decoded NoteSequence
  note_sequence: note_seq.NoteSequence = dataclasses.field(
      default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220))
def decode_note_onset_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec,
) -> None:
  """Process note onset event and update decoding state.

  Args:
    state: Decoding state whose NoteSequence receives the new note.
    time: Onset time of the note.
    event: Event to decode; only 'pitch' events are accepted.
    codec: Not used by this function.

  Raises:
    ValueError: If the event is not a 'pitch' event.
  """
  if event.type != 'pitch':
    raise ValueError('unexpected event type: %s' % event.type)

  # Onset-only decoding has no offsets or velocities, so use the defaults.
  sequence = state.note_sequence
  sequence.notes.add(
      start_time=time, end_time=time + DEFAULT_NOTE_DURATION,
      pitch=event.value, velocity=DEFAULT_VELOCITY)
  sequence.total_time = max(sequence.total_time,
                            time + DEFAULT_NOTE_DURATION)
def _add_note_to_sequence(
    ns: note_seq.NoteSequence,
    start_time: float, end_time: float, pitch: int, velocity: int,
    program: int = 0, is_drum: bool = False
) -> None:
  """Append a note to `ns`, enforcing the minimum note duration.

  Args:
    ns: NoteSequence to modify in place.
    start_time: Start time of the note.
    end_time: Requested end time; raised to start_time + MIN_NOTE_DURATION if
      it is earlier than that.
    pitch: Pitch of the note.
    velocity: Velocity of the note.
    program: Program of the note.
    is_drum: Whether the note is a drum hit.
  """
  earliest_end = start_time + MIN_NOTE_DURATION
  clamped_end = max(end_time, earliest_end)
  ns.notes.add(
      start_time=start_time, end_time=clamped_end,
      pitch=pitch, velocity=velocity, program=program, is_drum=is_drum)
  # Keep the sequence length covering the note just added.
  ns.total_time = max(ns.total_time, clamped_end)
def decode_note_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec
) -> None:
  """Process note event and update decoding state.

  Handled event types:
    * 'velocity': sets the velocity applied to following pitch/drum events.
    * 'program': sets the program applied to following pitch events.
    * 'drum': adds a drum note of DEFAULT_NOTE_DURATION at `time`.
    * 'tie': closes the tie section, ending active notes not declared tied.
    * 'pitch': inside the tie section, marks an active note as tied; otherwise
      a note-off when the current velocity is zero, else a note-on.

  Args:
    state: Decoding state, updated in place.
    time: Time of the event; must not be earlier than the state's current time.
    event: Event to decode.
    codec: Codec used to determine the number of velocity bins.

  Raises:
    ValueError: If time goes backwards, the event type is unknown, or the event
      is inconsistent with the current state (see the messages below).
  """
  if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))
  state.current_time = time

  event_type = event.type

  if event_type == 'velocity':
    # velocity change
    num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
    state.current_velocity = vocabularies.bin_to_velocity(
        event.value, num_velocity_bins)
    return

  if event_type == 'program':
    # program change
    state.current_program = event.value
    return

  if event_type == 'drum':
    # drum onset (drums have no offset)
    if state.current_velocity == 0:
      raise ValueError('velocity cannot be zero for drum event')
    _add_note_to_sequence(
        state.note_sequence,
        start_time=time, end_time=time + DEFAULT_NOTE_DURATION,
        pitch=event.value, velocity=state.current_velocity, is_drum=True)
    return

  if event_type == 'tie':
    # end of tie section; end active notes that weren't declared tied
    if not state.is_tie_section:
      raise ValueError('tie section end event when not in tie section')
    # Iterate over a snapshot of the keys because entries are removed below.
    for active_key in list(state.active_pitches.keys()):
      if active_key in state.tied_pitches:
        continue
      active_pitch, active_program = active_key
      onset_time, onset_velocity = state.active_pitches.pop(active_key)
      _add_note_to_sequence(
          state.note_sequence,
          start_time=onset_time, end_time=state.current_time,
          pitch=active_pitch, velocity=onset_velocity,
          program=active_program)
    state.is_tie_section = False
    return

  if event_type != 'pitch':
    raise ValueError('unexpected event type: %s' % event.type)

  pitch = event.value
  key = (pitch, state.current_program)

  if state.is_tie_section:
    # "tied" pitch
    if key not in state.active_pitches:
      raise ValueError('inactive pitch/program in tie section: %d/%d' %
                       (pitch, state.current_program))
    if key in state.tied_pitches:
      raise ValueError('pitch/program is already tied: %d/%d' %
                       (pitch, state.current_program))
    state.tied_pitches.add(key)
    return

  if state.current_velocity == 0:
    # note offset
    if key not in state.active_pitches:
      raise ValueError('note-off for inactive pitch/program: %d/%d' %
                       (pitch, state.current_program))
    onset_time, onset_velocity = state.active_pitches.pop(key)
    _add_note_to_sequence(
        state.note_sequence, start_time=onset_time, end_time=time,
        pitch=pitch, velocity=onset_velocity, program=state.current_program)
    return

  # note onset
  if key in state.active_pitches:
    # The pitch is already active; this shouldn't really happen but we'll
    # try to handle it gracefully by ending the previous note and starting a
    # new one.
    onset_time, onset_velocity = state.active_pitches.pop(key)
    _add_note_to_sequence(
        state.note_sequence, start_time=onset_time, end_time=time,
        pitch=pitch, velocity=onset_velocity, program=state.current_program)
  state.active_pitches[key] = (time, state.current_velocity)
def begin_tied_pitches_section(state: NoteDecodingState) -> None:
  """Begin the tied pitches section at the start of a segment.

  Args:
    state: Decoding state; it is switched into tie-section mode and given a
      fresh, empty set of tied pitches.
  """
  state.is_tie_section = True
  # Rebind instead of clearing so a previously held set is left untouched.
  state.tied_pitches = set()
def flush_note_decoding_state(
    state: NoteDecodingState
) -> note_seq.NoteSequence:
  """End all active notes and return resulting NoteSequence.

  Args:
    state: Decoding state to flush; its active notes are removed and added to
      its NoteSequence, and its current time may be advanced.

  Returns:
    The state's NoteSequence with instruments assigned.
  """
  # Advance the current time far enough that every still-active note can end
  # there with at least the minimum duration.
  earliest_valid_ends = [
      onset_time + MIN_NOTE_DURATION
      for onset_time, _ in state.active_pitches.values()
  ]
  state.current_time = max([state.current_time] + earliest_valid_ends)

  # Close the notes in their existing order; iterate over a snapshot because
  # the mapping shrinks as we go.
  for key in tuple(state.active_pitches.keys()):
    pitch, program = key
    onset_time, onset_velocity = state.active_pitches.pop(key)
    _add_note_to_sequence(
        state.note_sequence, start_time=onset_time, end_time=state.current_time,
        pitch=pitch, velocity=onset_velocity, program=program)

  assign_instruments(state.note_sequence)
  return state.note_sequence
class NoteEncodingSpecType(run_length_encoding.EventEncodingSpec):
  """Event encoding spec type used for the note encoding specs in this module.

  Adds nothing to the base class.
  """
# Encoding spec for modeling note onsets only: no encoding state is kept,
# nothing special happens at the start of a segment, and flushing just hands
# back the decoded NoteSequence.
NoteOnsetEncodingSpec = NoteEncodingSpecType(
    # encoding side
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None,
    # decoding side
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_onset_event,
    flush_decoding_state_fn=lambda state: state.note_sequence,
)
# Encoding spec for modeling onsets and offsets: no encoding state is kept and
# nothing special happens at the start of a segment; flushing ends any notes
# that are still active.
NoteEncodingSpec = NoteEncodingSpecType(
    # encoding side
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None,
    # decoding side
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state,
)
# Encoding spec for modeling onsets and offsets, with a "tie" section at the
# beginning of each segment listing already-active notes. The encoder tracks
# active pitches so it can emit that section, and the decoder opens the tie
# section at the start of every segment.
NoteEncodingWithTiesSpec = NoteEncodingSpecType(
    # encoding side
    init_encoding_state_fn=NoteEncodingState,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=note_encoding_state_to_events,
    # decoding side
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=begin_tied_pitches_section,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state,
)