# Copyright 2022 The MT3 Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for run length encoding.""" import dataclasses from typing import Any, Callable, Mapping, MutableMapping, Tuple, Optional, Sequence, TypeVar from absl import logging from mt3 import event_codec import numpy as np import seqio import tensorflow as tf Event = event_codec.Event # These should be type variables, but unfortunately those are incompatible with # dataclasses. EventData = Any EncodingState = Any DecodingState = Any DecodeResult = Any T = TypeVar('T', bound=EventData) ES = TypeVar('ES', bound=EncodingState) DS = TypeVar('DS', bound=DecodingState) @dataclasses.dataclass class EventEncodingSpec: """Spec for encoding events.""" # initialize encoding state init_encoding_state_fn: Callable[[], EncodingState] # convert EventData into zero or more events, updating encoding state encode_event_fn: Callable[[EncodingState, EventData, event_codec.Codec], Sequence[event_codec.Event]] # convert encoding state (at beginning of segment) into events encoding_state_to_events_fn: Optional[Callable[[EncodingState], Sequence[event_codec.Event]]] # create empty decoding state init_decoding_state_fn: Callable[[], DecodingState] # update decoding state when entering new segment begin_decoding_segment_fn: Callable[[DecodingState], None] # consume time and Event and update decoding state decode_event_fn: Callable[ [DecodingState, float, event_codec.Event, event_codec.Codec], None] # flush decoding state into result flush_decoding_state_fn: Callable[[DecodingState], DecodeResult] def encode_and_index_events( state: ES, event_times: Sequence[float], event_values: Sequence[T], encode_event_fn: Callable[[ES, T, event_codec.Codec], Sequence[event_codec.Event]], codec: event_codec.Codec, frame_times: Sequence[float], encoding_state_to_events_fn: Optional[ Callable[[ES], Sequence[event_codec.Event]]] = None, ) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: """Encode a sequence of timed events and index to audio frame times. Encodes time shifts as repeated single step shifts for later run length encoding. Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio frame. This can be used e.g. to prepend events representing the current state to a targets segment. Args: state: Initial event encoding state. event_times: Sequence of event times. event_values: Sequence of event values. encode_event_fn: Function that transforms event value into a sequence of one or more event_codec.Event objects. codec: An event_codec.Codec object that maps Event objects to indices. frame_times: Time for every audio frame. encoding_state_to_events_fn: Function that transforms encoding state into a sequence of one or more event_codec.Event objects. Returns: events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of another. event_end_indices: Corresponding end event index for every audio frame. Used to ensure when slicing that one chunk ends where the next begins. Should always be true that event_end_indices[i] = event_start_indices[i + 1]. state_events: Encoded "state" events representing the encoding state before each event. state_event_indices: Corresponding state event index for every audio frame. """ indices = np.argsort(event_times, kind='stable') event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] event_values = [event_values[i] for i in indices] events = [] state_events = [] event_start_indices = [] state_event_indices = [] cur_step = 0 cur_event_idx = 0 cur_state_event_idx = 0 def fill_event_start_indices_to_cur_step(): while(len(event_start_indices) < len(frame_times) and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second): event_start_indices.append(cur_event_idx) state_event_indices.append(cur_state_event_idx) for event_step, event_value in zip(event_steps, event_values): while event_step > cur_step: events.append(codec.encode_event(Event(type='shift', value=1))) cur_step += 1 fill_event_start_indices_to_cur_step() cur_event_idx = len(events) cur_state_event_idx = len(state_events) if encoding_state_to_events_fn: # Dump state to state events *before* processing the next event, because # we want to capture the state prior to the occurrence of the event. for e in encoding_state_to_events_fn(state): state_events.append(codec.encode_event(e)) for e in encode_event_fn(state, event_value, codec): events.append(codec.encode_event(e)) # After the last event, continue filling out the event_start_indices array. # The inequality is not strict because if our current step lines up exactly # with (the start of) an audio frame, we need to add an additional shift event # to "cover" that frame. while cur_step / codec.steps_per_second <= frame_times[-1]: events.append(codec.encode_event(Event(type='shift', value=1))) cur_step += 1 fill_event_start_indices_to_cur_step() cur_event_idx = len(events) # Now fill in event_end_indices. We need this extra array to make sure that # when we slice events, each slice ends exactly where the subsequent slice # begins. event_end_indices = event_start_indices[1:] + [len(events)] events = np.array(events) state_events = np.array(state_events) event_start_indices = np.array(event_start_indices) event_end_indices = np.array(event_end_indices) state_event_indices = np.array(state_event_indices) return (events, event_start_indices, event_end_indices, state_events, state_event_indices) @seqio.map_over_dataset def extract_target_sequence_with_indices(features, state_events_end_token=None): """Extract target sequence corresponding to audio token segment.""" target_start_idx = features['input_event_start_indices'][0] target_end_idx = features['input_event_end_indices'][-1] features['targets'] = features['targets'][target_start_idx:target_end_idx] if state_events_end_token is not None: # Extract the state events corresponding to the audio start token, and # prepend them to the targets array. state_event_start_idx = features['input_state_event_indices'][0] state_event_end_idx = state_event_start_idx + 1 while features['state_events'][ state_event_end_idx - 1] != state_events_end_token: state_event_end_idx += 1 features['targets'] = tf.concat([ features['state_events'][state_event_start_idx:state_event_end_idx], features['targets'] ], axis=0) return features def remove_redundant_state_changes_fn( codec: event_codec.Codec, feature_key: str = 'targets', state_change_event_types: Sequence[str] = () ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: """Return preprocessing function that removes redundant state change events. Args: codec: The event_codec.Codec used to interpret the events. feature_key: The feature key for which to remove redundant state changes. state_change_event_types: A list of event types that represent state changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones will be removed. Returns: A preprocessing function that removes redundant state change events. """ state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] def remove_redundant_state_changes( features: MutableMapping[str, Any], ) -> Mapping[str, Any]: """Remove redundant tokens e.g. duplicate velocity changes from sequence.""" current_state = tf.zeros(len(state_change_event_ranges), dtype=tf.int32) output = tf.constant([], dtype=tf.int32) for event in features[feature_key]: # Let autograph know that the shape of 'output' will change during the # loop. tf.autograph.experimental.set_loop_options( shape_invariants=[(output, tf.TensorShape([None]))]) is_redundant = False for i, (min_index, max_index) in enumerate(state_change_event_ranges): if (min_index <= event) and (event <= max_index): if current_state[i] == event: is_redundant = True current_state = tf.tensor_scatter_nd_update( current_state, indices=[[i]], updates=[event]) if not is_redundant: output = tf.concat([output, [event]], axis=0) features[feature_key] = output return features return seqio.map_over_dataset(remove_redundant_state_changes) def run_length_encode_shifts_fn( codec: event_codec.Codec, feature_key: str = 'targets' ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: """Return a function that run-length encodes shifts for a given codec. Args: codec: The Codec to use for shift events. feature_key: The feature key for which to run-length encode shifts. Returns: A preprocessing function that run-length encodes single-step shifts. """ def run_length_encode_shifts( features: MutableMapping[str, Any] ) -> Mapping[str, Any]: """Combine leading/interior shifts, trim trailing shifts. Args: features: Dict of features to process. Returns: A dict of features. """ events = features[feature_key] shift_steps = 0 total_shift_steps = 0 output = tf.constant([], dtype=tf.int32) for event in events: # Let autograph know that the shape of 'output' will change during the # loop. tf.autograph.experimental.set_loop_options( shape_invariants=[(output, tf.TensorShape([None]))]) if codec.is_shift_event_index(event): shift_steps += 1 total_shift_steps += 1 else: # Once we've reached a non-shift event, RLE all previous shift events # before outputting the non-shift event. if shift_steps > 0: shift_steps = total_shift_steps while shift_steps > 0: output_steps = tf.minimum(codec.max_shift_steps, shift_steps) output = tf.concat([output, [output_steps]], axis=0) shift_steps -= output_steps output = tf.concat([output, [event]], axis=0) features[feature_key] = output return features return seqio.map_over_dataset(run_length_encode_shifts) def merge_run_length_encoded_targets( targets: np.ndarray, codec: event_codec.Codec ) -> Sequence[int]: """Merge multiple tracks of target events into a single stream. Args: targets: A 2D array (# tracks by # events) of integer event values. codec: The event_codec.Codec used to interpret the events. Returns: A 1D array of merged events. """ num_tracks = tf.shape(targets)[0] targets_length = tf.shape(targets)[1] current_step = 0 current_offsets = tf.zeros(num_tracks, dtype=tf.int32) output = tf.constant([], dtype=tf.int32) done = tf.constant(False) while not done: # Let autograph know that the shape of 'output' will change during the loop. tf.autograph.experimental.set_loop_options( shape_invariants=[(output, tf.TensorShape([None]))]) # Determine which targets track has the earliest next step. next_step = codec.max_shift_steps + 1 next_track = -1 for i in range(num_tracks): if (current_offsets[i] == targets_length or targets[i][current_offsets[i]] == 0): # Already reached the end of this targets track. # (Zero is technically a valid shift event but we never actually use it; # it is always padding.) continue if not codec.is_shift_event_index(targets[i][current_offsets[i]]): # The only way we would be at a non-shift event is if we have not yet # reached the first shift event, which means we're at step zero. next_step = 0 next_track = i elif targets[i][current_offsets[i]] < next_step: next_step = targets[i][current_offsets[i]] next_track = i if next_track == -1: # We've already merged all of the target tracks in their entirety. done = tf.constant(True) break if next_step == current_step and next_step > 0: # We don't need to include the shift event itself as it's the same step as # the previous shift. start_offset = current_offsets[next_track] + 1 else: start_offset = current_offsets[next_track] # Merge in events up to but not including the next shift. end_offset = start_offset + 1 while end_offset < targets_length and not codec.is_shift_event_index( targets[next_track][end_offset]): end_offset += 1 output = tf.concat( [output, targets[next_track][start_offset:end_offset]], axis=0) current_step = next_step current_offsets = tf.tensor_scatter_nd_update( current_offsets, indices=[[next_track]], updates=[end_offset]) return output def decode_events( state: DS, tokens: np.ndarray, start_time: int, max_time: Optional[int], codec: event_codec.Codec, decode_event_fn: Callable[[DS, float, event_codec.Event, event_codec.Codec], None], ) -> Tuple[int, int]: """Decode a series of tokens, maintaining a decoding state object. Args: state: Decoding state object; will be modified in-place. tokens: event tokens to convert. start_time: offset start time if decoding in the middle of a sequence. max_time: Events at or beyond this time will be dropped. codec: An event_codec.Codec object that maps indices to Event objects. decode_event_fn: Function that consumes an Event (and the current time) and updates the decoding state. Returns: invalid_events: number of events that could not be decoded. dropped_events: number of events dropped due to max_time restriction. """ invalid_events = 0 dropped_events = 0 cur_steps = 0 cur_time = start_time token_idx = 0 for token_idx, token in enumerate(tokens): try: event = codec.decode_event_index(token) except ValueError: invalid_events += 1 continue if event.type == 'shift': cur_steps += event.value cur_time = start_time + cur_steps / codec.steps_per_second if max_time and cur_time > max_time: dropped_events = len(tokens) - token_idx break else: cur_steps = 0 try: decode_event_fn(state, cur_time, event, codec) except ValueError: invalid_events += 1 logging.info( 'Got invalid event when decoding event %s at time %f. ' 'Invalid event counter now at %d.', event, cur_time, invalid_events, exc_info=True) continue return invalid_events, dropped_events