import csv
import glob
import math
import numbers
import os
import random
import typing
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict
from typing import List

import numpy as np
import torch
import torchaudio
from flatten_dict import flatten
from flatten_dict import unflatten

@dataclass
class Info:
    """Shim for torchaudio.info API changes."""

    sample_rate: float
    num_frames: int

    @property
    def duration(self) -> float:
        return self.num_frames / self.sample_rate

def info(audio_path: str):
    """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.

    Parameters
    ----------
    audio_path : str
        Path to audio file.
    """
    # Try the default backend first, then fall back to the soundfile backend.
    try:
        info = torchaudio.info(str(audio_path))
    except:  # pragma: no cover
        info = torchaudio.backend.soundfile_backend.info(str(audio_path))

    if isinstance(info, tuple):  # pragma: no cover
        signal_info = info[0]
        info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
    else:
        info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)

    return info
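
# Example (illustrative; the file name is an assumption):
#   >>> meta = info("speech.wav")
#   >>> meta.sample_rate, meta.num_frames, meta.duration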

def ensure_tensor(
    x: typing.Union[np.ndarray, torch.Tensor, float, int],
    ndim: int = None,
    batch_size: int = None,
):
    """Ensures that the input ``x`` is a tensor with the specified
    number of dimensions and batch size.

    Parameters
    ----------
    x : typing.Union[np.ndarray, torch.Tensor, float, int]
        Data that will become a tensor on its way out.
    ndim : int, optional
        How many dimensions should be in the output, by default None
    batch_size : int, optional
        The batch size of the output, by default None

    Returns
    -------
    torch.Tensor
        Modified version of ``x`` as a tensor.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x)
    if ndim is not None:
        assert x.ndim <= ndim
        while x.ndim < ndim:
            x = x.unsqueeze(-1)
    if batch_size is not None:
        if x.shape[0] != batch_size:
            shape = list(x.shape)
            shape[0] = batch_size
            x = x.expand(*shape)
    return x
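
# Example (illustrative): broadcast a scalar gain to a batch of 4 signals
# shaped (batch, channels, samples).
#   >>> gain = ensure_tensor(0.5, ndim=3, batch_size=4)
#   >>> gain.shape
#   torch.Size([4, 1, 1])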

def _get_value(other):
    from . import AudioSignal

    if isinstance(other, AudioSignal):
        return other.audio_data
    return other

def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
    """Closest frequency bin given a frequency, number
    of bins, and a sampling rate.

    Parameters
    ----------
    hz : torch.Tensor
        Tensor of frequencies in Hz.
    n_fft : int
        Number of FFT bins.
    sample_rate : int
        Sample rate of audio.

    Returns
    -------
    torch.Tensor
        Closest bins to the data.
    """
    shape = hz.shape
    hz = hz.flatten()
    freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
    hz[hz > sample_rate / 2] = sample_rate / 2
    closest = (hz[None, :] - freqs[:, None]).abs()
    closest_bins = closest.min(dim=0).indices
    return closest_bins.reshape(*shape)
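
# Example (illustrative): map frequencies to the nearest STFT bin; anything
# above Nyquist is clamped to the Nyquist frequency first.
#   >>> hz = torch.tensor([440.0, 30000.0])
#   >>> hz_to_bin(hz, n_fft=2048, sample_rate=44100)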

def random_state(seed: typing.Union[int, np.random.RandomState]):
    """
    Turn seed into a np.random.RandomState instance.

    Parameters
    ----------
    seed : typing.Union[int, np.random.RandomState] or None
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.

    Returns
    -------
    np.random.RandomState
        Random state object.

    Raises
    ------
    ValueError
        If seed is not valid, an error is thrown.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    elif isinstance(seed, (numbers.Integral, np.integer, int)):
        return np.random.RandomState(seed)
    elif isinstance(seed, np.random.RandomState):
        return seed
    else:
        raise ValueError(
            "%r cannot be used to seed a numpy.random.RandomState instance" % seed
        )
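
# Example (illustrative): each call below yields a usable RandomState.
#   >>> random_state(None)                       # global np.random singleton
#   >>> random_state(42)                         # new RandomState seeded with 42
#   >>> random_state(np.random.RandomState(0))   # passed through unchanged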

def seed(random_seed, set_cudnn=False):
    """
    Seeds all random states with the same random seed
    for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
    random generators.
    For full reproducibility, two further options must be set
    according to the torch documentation:
    https://pytorch.org/docs/stable/notes/randomness.html
    To do this, ``set_cudnn`` must be True. It defaults to
    False, since setting it to True results in a performance
    hit.

    Args:
        random_seed (int): integer corresponding to random seed to
        use.
        set_cudnn (bool): Whether or not to set cudnn into deterministic
        mode and off of benchmark mode. Defaults to False.
    """
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    if set_cudnn:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
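
# Example (illustrative): seed everything at the top of a training script.
#   >>> seed(0)                  # fast, but cuDNN kernels may still be non-deterministic
#   >>> seed(0, set_cudnn=True)  # fully reproducible, at a performance cost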

@contextmanager
def _close_temp_files(tmpfiles: list):
    """Utility function for creating a context and closing all temporary files
    once the context is exited. For correct functionality, all temporary file
    handles created inside the context must be appended to the ``tmpfiles``
    list.

    This function is taken wholesale from Scaper.

    Parameters
    ----------
    tmpfiles : list
        List of temporary file handles
    """

    def _close():
        for t in tmpfiles:
            try:
                t.close()
                os.unlink(t.name)
            except:
                pass

    try:
        yield
    except:  # pragma: no cover
        _close()
        raise
    _close()

AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]


def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
    """Finds all audio files in a directory recursively.
    Returns a list.

    Parameters
    ----------
    folder : str
        Folder to look for audio files in, recursively.
    ext : List[str], optional
        Extensions to look for (including the leading dot), by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.
    """
    folder = Path(folder)
    # Take care of the case where the user has passed an audio file directly
    # into one of the calling functions.
    if str(folder).endswith(tuple(ext)):
        # If, however, there's a glob in the path, we need to
        # return the glob, not the file.
        if "*" in str(folder):
            return glob.glob(str(folder), recursive=("**" in str(folder)))
        else:
            return [folder]

    files = []
    for x in ext:
        files += folder.glob(f"**/*{x}")
    return files
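
# Example (illustrative; the folder names are assumptions):
#   >>> wavs = find_audio("data/speech", ext=[".wav"])
#   >>> mp3s = find_audio("data/speech/**/*.mp3")   # glob patterns are passed through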

def read_sources(
    sources: List[str],
    remove_empty: bool = True,
    relative_path: str = "",
    ext: List[str] = AUDIO_EXTENSIONS,
):
    """Reads audio sources that can either be folders
    full of audio files, or CSV files that contain paths
    to audio files. CSV files that adhere to the expected
    format can be generated by
    :py:func:`audiotools.data.preprocess.create_csv`.

    Parameters
    ----------
    sources : List[str]
        List of audio sources to be converted into a
        list of lists of audio files.
    remove_empty : bool, optional
        Whether or not to remove rows with an empty "path"
        from each CSV file, by default True.
    relative_path : str, optional
        Path prepended to every audio path that is read, by default "".
    ext : List[str], optional
        Extensions to look for when a source is a folder, by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.

    Returns
    -------
    list
        List of lists of rows of CSV files.
    """
    files = []
    relative_path = Path(relative_path)
    for source in sources:
        source = str(source)
        _files = []
        if source.endswith(".csv"):
            with open(source, "r") as f:
                reader = csv.DictReader(f)
                for x in reader:
                    if remove_empty and x["path"] == "":
                        continue
                    if x["path"] != "":
                        x["path"] = str(relative_path / x["path"])
                    _files.append(x)
        else:
            for x in find_audio(source, ext=ext):
                x = str(relative_path / x)
                _files.append({"path": x})
        files.append(sorted(_files, key=lambda x: x["path"]))
    return files
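
# Example (illustrative; the source paths are assumptions): one CSV source and
# one folder source become two sorted lists of ``{"path": ...}`` rows.
#   >>> sources = read_sources(["train/speech.csv", "train/noise/"])
#   >>> len(sources)
#   2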

def choose_from_list_of_lists(
    state: np.random.RandomState, list_of_lists: list, p: float = None
):
    """Choose a single item from a list of lists.

    Parameters
    ----------
    state : np.random.RandomState
        Random state to use when choosing an item.
    list_of_lists : list
        A list of lists from which items will be drawn.
    p : float, optional
        Probabilities of choosing each list, by default None (uniform).

    Returns
    -------
    tuple
        The chosen item, the index of the list it was drawn from, and the
        index of the item within that list.
    """
    source_idx = state.choice(list(range(len(list_of_lists))), p=p)
    item_idx = state.randint(len(list_of_lists[source_idx]))
    return list_of_lists[source_idx][item_idx], source_idx, item_idx
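
# Example (illustrative): weight the first list three times as heavily as the second.
#   >>> state = random_state(0)
#   >>> item, source_idx, item_idx = choose_from_list_of_lists(
#   ...     state, [["a", "b"], ["c"]], p=[0.75, 0.25]
#   ... )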

@contextmanager
def chdir(newdir: typing.Union[Path, str]):
    """
    Context manager for switching directories to run a
    function. Useful for when you want to use relative
    paths to different runs.

    Parameters
    ----------
    newdir : typing.Union[Path, str]
        Directory to switch to.
    """
    curdir = os.getcwd()
    try:
        os.chdir(newdir)
        yield
    finally:
        os.chdir(curdir)
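
# Example (illustrative; the directory is an assumption):
#   >>> with chdir("runs/exp1"):
#   ...     files = find_audio(".")   # relative paths now resolve inside runs/exp1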

def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
    """Moves items in a batch (typically generated by a DataLoader as a list
    or a dict) to the specified device. This works even if dictionaries
    are nested.

    Parameters
    ----------
    batch : typing.Union[dict, list, torch.Tensor]
        Batch, typically generated by a dataloader, that will be moved to
        the device.
    device : str, optional
        Device to move batch to, by default "cpu"

    Returns
    -------
    typing.Union[dict, list, torch.Tensor]
        Batch with all values moved to the specified device.
    """
    if isinstance(batch, dict):
        batch = flatten(batch)
        for key, val in batch.items():
            try:
                batch[key] = val.to(device)
            except:
                pass
        batch = unflatten(batch)
    elif torch.is_tensor(batch):
        batch = batch.to(device)
    elif isinstance(batch, list):
        for i in range(len(batch)):
            try:
                batch[i] = batch[i].to(device)
            except:
                pass
    return batch
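
# Example (illustrative): nested dictionaries are flattened, moved to the device,
# and restored; values without a ``.to`` method are left as-is.
#   >>> batch = {"signal": {"audio": torch.zeros(2, 1, 100)}, "label": [0, 1]}
#   >>> batch = prepare_batch(batch, device="cpu")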

def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
    """Samples from a distribution defined by a tuple. The first
    item in the tuple is the distribution type, and the rest of the
    items are arguments to that distribution. The distribution function
    is gotten from the ``np.random.RandomState`` object.

    Parameters
    ----------
    dist_tuple : tuple
        Distribution tuple
    state : np.random.RandomState, optional
        Random state, or seed to use, by default None

    Returns
    -------
    typing.Union[float, int, str]
        Draw from the distribution.

    Examples
    --------
    Sample from a uniform distribution:

    >>> dist_tuple = ("uniform", 0, 1)
    >>> sample_from_dist(dist_tuple)

    Sample from a constant distribution:

    >>> dist_tuple = ("const", 0)
    >>> sample_from_dist(dist_tuple)

    Sample from a normal distribution:

    >>> dist_tuple = ("normal", 0, 0.5)
    >>> sample_from_dist(dist_tuple)
    """
    if dist_tuple[0] == "const":
        return dist_tuple[1]
    state = random_state(state)
    dist_fn = getattr(state, dist_tuple[0])
    return dist_fn(*dist_tuple[1:])

def collate(list_of_dicts: list, n_splits: int = None):
    """Collates a list of dictionaries (e.g. as returned by a
    dataloader) into a dictionary with batched values. This routine
    uses the default torch collate function for everything
    except AudioSignal objects, which are handled by the
    :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
    function.

    This function takes ``n_splits`` to enable splitting a batch
    into multiple sub-batches for the purposes of gradient accumulation,
    etc.

    Parameters
    ----------
    list_of_dicts : list
        List of dictionaries to be collated.
    n_splits : int
        Number of splits to make when creating the batches (split into
        sub-batches). Useful for things like gradient accumulation.

    Returns
    -------
    dict
        Dictionary containing batched data.
    """
    from . import AudioSignal

    batches = []
    list_len = len(list_of_dicts)

    return_list = False if n_splits is None else True
    n_splits = 1 if n_splits is None else n_splits
    n_items = int(math.ceil(list_len / n_splits))

    for i in range(0, list_len, n_items):
        # Flatten the dictionaries to avoid recursion.
        list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
        dict_of_lists = {
            k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
        }

        batch = {}
        for k, v in dict_of_lists.items():
            if isinstance(v, list):
                if all(isinstance(s, AudioSignal) for s in v):
                    batch[k] = AudioSignal.batch(v, pad_signals=True)
                else:
                    # Borrow the default collate fn from torch.
                    batch[k] = torch.utils.data._utils.collate.default_collate(v)
        batches.append(unflatten(batch))

    batches = batches[0] if not return_list else batches
    return batches
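
# Example (illustrative; ``dataset`` is an assumed map-style dataset of dicts):
# collate 16 items into 4 sub-batches of 4, e.g. for gradient accumulation.
#   >>> items = [dataset[i] for i in range(16)]
#   >>> sub_batches = collate(items, n_splits=4)
#   >>> len(sub_batches)
#   4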

BASE_SIZE = 864
DEFAULT_FIG_SIZE = (9, 3)

def format_figure(
    fig_size: tuple = None,
    title: str = None,
    fig=None,
    format_axes: bool = True,
    format: bool = True,
    font_color: str = "white",
):
    """Prettifies the spectrogram and waveform plots. A title
    can be inset into the top right corner, and the axes can be
    inset into the figure, allowing the data to take up the entire
    image. Used in

    - :py:func:`audiotools.core.display.DisplayMixin.specshow`
    - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
    - :py:func:`audiotools.core.display.DisplayMixin.wavespec`

    Parameters
    ----------
    fig_size : tuple, optional
        Size of figure, by default (9, 3)
    title : str, optional
        Title to inset in top right, by default None
    fig : matplotlib.figure.Figure, optional
        Figure object, if None ``plt.gcf()`` will be used, by default None
    format_axes : bool, optional
        Format the axes to be inside the figure, by default True
    format : bool, optional
        This formatting can be skipped entirely by passing ``format=False``
        to any of the plotting functions that use this formatter, by default True
    font_color : str, optional
        Color of font of axes, by default "white"
    """
    import matplotlib
    import matplotlib.pyplot as plt

    if fig_size is None:
        fig_size = DEFAULT_FIG_SIZE
    if not format:
        return
    if fig is None:
        fig = plt.gcf()
    fig.set_size_inches(*fig_size)
    axs = fig.axes

    pixels = (fig.get_size_inches() * fig.dpi)[0]
    font_scale = pixels / BASE_SIZE

    if format_axes:
        axs = fig.axes

        for ax in axs:
            ymin, _ = ax.get_ylim()
            xmin, _ = ax.get_xlim()

            ticks = ax.get_yticks()
            for t in ticks[2:-1]:
                t = axs[0].annotate(
                    f"{(t / 1000):2.1f}k",
                    xy=(xmin, t),
                    xycoords="data",
                    xytext=(5, -5),
                    textcoords="offset points",
                    ha="left",
                    va="top",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            ticks = ax.get_xticks()[2:]
            for t in ticks[:-1]:
                t = axs[0].annotate(
                    f"{t:2.1f}s",
                    xy=(t, ymin),
                    xycoords="data",
                    xytext=(5, 5),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            ax.margins(0, 0)
            ax.set_axis_off()
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())

        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

    if title is not None:
        t = axs[0].annotate(
            title,
            xy=(1, 1),
            xycoords="axes fraction",
            fontsize=20 * font_scale,
            xytext=(-5, -5),
            textcoords="offset points",
            ha="right",
            va="top",
            color="white",
        )
        t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
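
# Example (illustrative; ``signal`` is assumed to be an AudioSignal with the
# DisplayMixin plotting methods):
#   >>> signal.specshow()                                    # draw a spectrogram first
#   >>> format_figure(fig_size=(9, 3), title="clean speech")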

def generate_chord_dataset(
    max_voices: int = 8,
    sample_rate: int = 44100,
    num_items: int = 5,
    duration: float = 1.0,
    min_note: str = "C2",
    max_note: str = "C6",
    output_dir: Path = "chords",
):
    """
    Generates a toy multitrack dataset of chords, synthesized from sine waves.

    Parameters
    ----------
    max_voices : int, optional
        Maximum number of voices in a chord, by default 8
    sample_rate : int, optional
        Sample rate of audio, by default 44100
    num_items : int, optional
        Number of items to generate, by default 5
    duration : float, optional
        Duration of each item, by default 1.0
    min_note : str, optional
        Minimum note in the dataset, by default "C2"
    max_note : str, optional
        Maximum note in the dataset, by default "C6"
    output_dir : Path, optional
        Directory to save the dataset, by default "chords"

    Returns
    -------
    Path
        Directory the dataset was written to.
    """
    import librosa

    from . import AudioSignal
    from ..data.preprocess import create_csv

    min_midi = librosa.note_to_midi(min_note)
    max_midi = librosa.note_to_midi(max_note)

    tracks = []
    for idx in range(num_items):
        track = {}
        # Figure out how many voices to put in this track.
        num_voices = random.randint(1, max_voices)
        for voice_idx in range(num_voices):
            # Choose some random parameters for this voice.
            midinote = random.randint(min_midi, max_midi)
            dur = random.uniform(0.85 * duration, duration)

            sig = AudioSignal.wave(
                frequency=librosa.midi_to_hz(midinote),
                duration=dur,
                sample_rate=sample_rate,
                shape="sine",
            )
            track[f"voice_{voice_idx}"] = sig
        tracks.append(track)

    # Save the tracks to disk.
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)
    for idx, track in enumerate(tracks):
        track_dir = output_dir / f"track_{idx}"
        track_dir.mkdir(exist_ok=True)
        for voice_name, sig in track.items():
            sig.write(track_dir / f"{voice_name}.wav")

    all_voices = list(set([k for track in tracks for k in track.keys()]))
    voice_lists = {voice: [] for voice in all_voices}
    for track in tracks:
        for voice_name in all_voices:
            if voice_name in track:
                voice_lists[voice_name].append(track[voice_name].path_to_file)
            else:
                voice_lists[voice_name].append("")

    for voice_name, paths in voice_lists.items():
        create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)

    return output_dir
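
# Example (illustrative): write a small toy multitrack dataset to ./chords,
# with one folder per track and one CSV per voice.
#   >>> generate_chord_dataset(max_voices=4, num_items=10, output_dir="chords")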