import csv |
import glob |
import math |
import numbers |
import os |
import random |
import typing |
from contextlib import contextmanager |
from dataclasses import dataclass |
from pathlib import Path |
from typing import Dict |
from typing import List |
import numpy as np |
import torch |
import torchaudio |
from flatten_dict import flatten |
from flatten_dict import unflatten |
@dataclass
class Info:
    """Version-independent container for basic audio file metadata.

    Used as the return type of the ``torchaudio.info`` shim so callers see
    the same two fields regardless of which torchaudio API produced them.
    """

    # Sample rate reported for the file.
    sample_rate: float
    # Number of frames reported for the file.
    num_frames: int

    @property
    def duration(self) -> float:
        """Number of frames divided by the sample rate."""
        frames, rate = self.num_frames, self.sample_rate
        return frames / rate
def info(audio_path: str):
    """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.

    Parameters
    ----------
    audio_path : str
        Path to audio file.

    Returns
    -------
    Info
        Sample rate and number of frames of the file.
    """
    path = str(audio_path)
    try:
        meta = torchaudio.info(path)
    except Exception:
        # Fall back to the soundfile backend when the default call fails.
        # ``Exception`` (rather than a bare ``except``) keeps the fallback
        # while letting KeyboardInterrupt / SystemExit propagate.
        meta = torchaudio.backend.soundfile_backend.info(path)

    if isinstance(meta, tuple):
        # Tuple-style result: the first element carries ``rate``/``length``.
        signal_info = meta[0]
        return Info(sample_rate=signal_info.rate, num_frames=signal_info.length)

    return Info(sample_rate=meta.sample_rate, num_frames=meta.num_frames)
def ensure_tensor(
    x: typing.Union[np.ndarray, torch.Tensor, float, int],
    ndim: typing.Optional[int] = None,
    batch_size: typing.Optional[int] = None,
):
    """Ensures that the input ``x`` is a tensor of specified
    dimensions and batch size.

    Parameters
    ----------
    x : typing.Union[np.ndarray, torch.Tensor, float, int]
        Data that will become a tensor on its way out.
    ndim : int, optional
        How many dimensions should be in the output, by default None.
        Missing dimensions are appended at the end.
    batch_size : int, optional
        The batch size of the output, by default None. The first
        dimension is expanded (not copied) to this size.

    Returns
    -------
    torch.Tensor
        Modified version of ``x`` as a tensor.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x)

    if ndim is not None:
        assert x.ndim <= ndim, f"Input has {x.ndim} dims, more than ndim={ndim}"
        while x.ndim < ndim:
            x = x.unsqueeze(-1)

    if batch_size is not None:
        if x.ndim == 0:
            # A scalar has no first dimension to expand; give it one so
            # that ``x.shape[0]`` below does not raise IndexError.
            x = x.unsqueeze(0)
        if x.shape[0] != batch_size:
            shape = list(x.shape)
            shape[0] = batch_size
            x = x.expand(*shape)

    return x
def _get_value(other):
    """Return ``other.audio_data`` for an ``AudioSignal``, else ``other`` itself."""
    # Function-scope import, as in the original (presumably to avoid a
    # circular import with the package root -- TODO confirm).
    from . import AudioSignal

    return other.audio_data if isinstance(other, AudioSignal) else other
def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
    """Closest frequency bin given a frequency, number
    of bins, and a sampling rate.

    Frequencies above ``sample_rate / 2`` are treated as ``sample_rate / 2``.
    The input tensor is not modified.

    Parameters
    ----------
    hz : torch.Tensor
        Tensor of frequencies in Hz.
    n_fft : int
        Number of FFT bins.
    sample_rate : int
        Sample rate of audio.

    Returns
    -------
    torch.Tensor
        Closest bins to the data, with the same shape as ``hz``.
    """
    shape = hz.shape
    nyquist = sample_rate / 2

    # ``clamp`` returns a new tensor. The previous in-place masked assignment
    # wrote through ``flatten()`` (which can return a view or the input
    # itself) and so silently modified the caller's tensor.
    hz = hz.flatten().clamp(max=nyquist)

    # NOTE(review): the grid has ``2 + n_fft // 2`` points; kept as-is --
    # confirm this is the intended bin count.
    # Built on the input's device so non-CPU inputs work too.
    freqs = torch.linspace(0, nyquist, 2 + n_fft // 2, device=hz.device)

    # Distance from every grid frequency (rows) to every query (columns).
    closest = (hz[None, :] - freqs[:, None]).abs()
    closest_bins = closest.min(dim=0).indices
    return closest_bins.reshape(shape)
def random_state(seed: typing.Union[int, np.random.RandomState]):
    """
    Turn seed into a np.random.RandomState instance.

    Parameters
    ----------
    seed : typing.Union[int, np.random.RandomState] or None
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.

    Returns
    -------
    np.random.RandomState
        Random state object.

    Raises
    ------
    ValueError
        If seed is not valid, an error is thrown.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, np.random.RandomState):
        return seed
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)

    # f-string with ``!r`` instead of ``"%r" % seed``: the %-form raised
    # TypeError (not ValueError) when ``seed`` was a tuple.
    raise ValueError(
        f"{seed!r} cannot be used to seed a numpy.random.RandomState instance"
    )
def seed(random_seed, set_cudnn=False):
    """
    Apply one seed to the ``torch``, ``numpy`` and ``random`` generators
    so that runs can be reproduced.

    The torch notes on randomness
    (https://pytorch.org/docs/stable/notes/randomness.html) list two
    further cudnn settings needed for full reproducibility. They are only
    applied when ``set_cudnn`` is True, because enabling them results in
    a performance hit.

    Args:
        random_seed (int): integer corresponding to random seed to
            use.
        set_cudnn (bool): Whether or not to set cudnn into deterministic
            mode and off of benchmark mode. Defaults to False.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(random_seed)

    if not set_cudnn:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
@contextmanager |
def _close_temp_files(tmpfiles: list): |
"""Utility function for creating a context and closing all temporary files |
once the context is exited. For correct functionality, all temporary file |
handles created inside the context must be appended to the ```tmpfiles``` |
list. |
This function is taken wholesale from Scaper. |
Parameters |
---------- |
tmpfiles : list |
List of temporary file handles |
""" |
def _close(): |
for t in tmpfiles: |
try: |
t.close() |
os.unlink(t.name) |
except: |
pass |
try: |
yield |
except: |
_close() |
raise |
_close() |
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]


def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
    """Collects audio files below a directory, searching recursively.

    ``folder`` may also be a single file path or a glob pattern, provided
    it ends with one of the extensions in ``ext``.

    Parameters
    ----------
    folder : str
        Folder to look for audio files in, recursively. A path ending in
        one of ``ext`` is returned as-is, or expanded with ``glob`` if it
        contains ``*``.
    ext : List[str], optional
        File name suffixes to match, including the leading dot, by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.

    Returns
    -------
    list
        Matching files. Entries are ``Path`` objects, except for glob
        patterns, which yield the strings returned by ``glob.glob``.
    """
    folder = Path(folder)
    as_text = str(folder)

    if not as_text.endswith(tuple(ext)):
        # A directory: search it recursively, one extension at a time.
        return [match for suffix in ext for match in folder.glob(f"**/*{suffix}")]

    if "*" not in as_text:
        # A single concrete file.
        return [folder]

    # A glob pattern; only recurse when the pattern asks for it.
    return glob.glob(as_text, recursive="**" in as_text)
def read_sources(
    sources: List[str],
    remove_empty: bool = True,
    relative_path: str = "",
    ext: List[str] = AUDIO_EXTENSIONS,
):
    """Reads audio sources that can either be folders
    full of audio files, or CSV files that contain paths
    to audio files. CSV files that adhere to the expected
    format can be generated by
    :py:func:`audiotools.data.preprocess.create_csv`.

    Parameters
    ----------
    sources : List[str]
        List of audio sources to be converted into a
        list of lists of audio files.
    remove_empty : bool, optional
        Whether or not to remove rows with an empty "path"
        from each CSV file, by default True.
    relative_path : str, optional
        Path that is prepended to every non-empty "path" entry,
        by default "".
    ext : List[str], optional
        Extensions passed to ``find_audio`` for sources that are not
        CSV files, by default ``AUDIO_EXTENSIONS``.

    Returns
    -------
    list
        One list per source, each holding dictionaries with at least a
        "path" key (the CSV rows, or ``{"path": ...}`` for discovered
        files), sorted by "path".
    """
    files = []
    relative_path = Path(relative_path)

    for source in sources:
        source = str(source)

        if source.endswith(".csv"):
            rows = []
            # ``newline=""`` is what the csv module expects, so that quoted
            # fields containing line breaks are parsed correctly.
            with open(source, "r", newline="") as f:
                for row in csv.DictReader(f):
                    if row["path"] == "":
                        if remove_empty:
                            continue
                    else:
                        row["path"] = str(relative_path / row["path"])
                    rows.append(row)
        else:
            rows = [
                {"path": str(relative_path / found)}
                for found in find_audio(source, ext=ext)
            ]

        files.append(sorted(rows, key=lambda row: row["path"]))

    return files
def choose_from_list_of_lists(
    state: np.random.RandomState, list_of_lists: list, p: float = None
):
    """Draw one item at random from a list of lists.

    A list is picked first (optionally weighted by ``p``), then an item
    index is drawn uniformly within that list.

    Parameters
    ----------
    state : np.random.RandomState
        Random state to use when choosing an item.
    list_of_lists : list
        A list of lists from which items will be drawn.
    p : float, optional
        Probabilities of each list, forwarded to ``state.choice``,
        by default None

    Returns
    -------
    tuple
        ``(item, source_idx, item_idx)``: the chosen item, the index of
        the list it came from, and its index within that list.
    """
    n_sources = len(list_of_lists)
    source_idx = state.choice(list(range(n_sources)), p=p)

    chosen_list = list_of_lists[source_idx]
    item_idx = state.randint(len(chosen_list))

    return chosen_list[item_idx], source_idx, item_idx
@contextmanager
def chdir(newdir: typing.Union[Path, str]):
    """
    Temporarily make ``newdir`` the working directory.

    The working directory in effect on entry is restored when the
    context exits, including when the body raises. Handy for running
    code that uses paths relative to a particular run directory.

    Parameters
    ----------
    newdir : typing.Union[Path, str]
        Directory to switch to.
    """
    # Remember where we started so it can be restored on the way out.
    previous_dir = os.getcwd()
    try:
        os.chdir(newdir)
        yield
    finally:
        os.chdir(previous_dir)
def _move_to_device(value, device):
    """Return ``value.to(device)``, or ``value`` unchanged if the move fails."""
    try:
        return value.to(device)
    except Exception:
        # Best-effort: items that cannot be moved (strings, ints, ...) are
        # left as they are. ``Exception`` rather than a bare ``except`` so
        # KeyboardInterrupt / SystemExit are not swallowed.
        return value


def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
    """Moves items in a batch (typically generated by a DataLoader as a list
    or a dict) to the specified device. This works even if dictionaries
    are nested.

    Items that cannot be moved are kept unchanged. A list is updated in
    place and returned; any other input type that is not a dict, tensor or
    list is returned as-is.

    Parameters
    ----------
    batch : typing.Union[dict, list, torch.Tensor]
        Batch, typically generated by a dataloader, that will be moved to
        the device.
    device : str, optional
        Device to move batch to, by default "cpu"

    Returns
    -------
    typing.Union[dict, list, torch.Tensor]
        Batch with all values moved to the specified device.
    """
    if isinstance(batch, dict):
        # Flatten so nested values are reached, then restore the nesting.
        flat = flatten(batch)
        for key, val in flat.items():
            flat[key] = _move_to_device(val, device)
        return unflatten(flat)

    if torch.is_tensor(batch):
        return batch.to(device)

    if isinstance(batch, list):
        for i, item in enumerate(batch):
            batch[i] = _move_to_device(item, device)

    return batch
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
    """Draws a value from a distribution described by a tuple.

    The first entry names the distribution; the remaining entries are its
    arguments. ``"const"`` returns the second entry directly. Any other
    name is looked up as a method on the ``np.random.RandomState``
    obtained from ``state`` and called with the remaining entries.

    Parameters
    ----------
    dist_tuple : tuple
        Distribution tuple
    state : np.random.RandomState, optional
        Random state, or seed to use, by default None

    Returns
    -------
    typing.Union[float, int, str]
        Draw from the distribution.

    Examples
    --------
    Sample from a uniform distribution:

    >>> dist_tuple = ("uniform", 0, 1)
    >>> sample_from_dist(dist_tuple)

    Sample from a constant distribution:

    >>> dist_tuple = ("const", 0)
    >>> sample_from_dist(dist_tuple)

    Sample from a normal distribution:

    >>> dist_tuple = ("normal", 0, 0.5)
    >>> sample_from_dist(dist_tuple)
    """
    dist_name = dist_tuple[0]

    # Constants need no random state at all.
    if dist_name == "const":
        return dist_tuple[1]

    dist_args = dist_tuple[1:]
    rng = random_state(state)
    return getattr(rng, dist_name)(*dist_args)
def collate(list_of_dicts: list, n_splits: int = None):
    """Collates a list of dictionaries (e.g. as returned by a
    dataloader) into a dictionary with batched values. This routine
    uses the default torch collate function for everything
    except AudioSignal objects, which are handled by the
    :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
    function.

    This function takes n_splits to enable splitting a batch
    into multiple sub-batches for the purposes of gradient accumulation,
    etc.

    Parameters
    ----------
    list_of_dicts : list
        List of dictionaries to be collated. Must not be empty. The keys
        of the first dictionary in each sub-batch decide which keys are
        collated.
    n_splits : int
        Number of splits to make when creating the batches (split into
        sub-batches). Useful for things like gradient accumulation.

    Returns
    -------
    dict or list
        Dictionary containing batched data when ``n_splits`` is None,
        otherwise a list of such dictionaries, one per sub-batch.

    Raises
    ------
    ValueError
        If ``list_of_dicts`` is empty.
    """
    list_len = len(list_of_dicts)
    if list_len == 0:
        # Previously this surfaced as an obscure "range() arg 3 must not
        # be zero" ValueError further down.
        raise ValueError("collate() needs at least one item, got an empty list.")

    from . import AudioSignal

    return_list = n_splits is not None
    n_splits = 1 if n_splits is None else n_splits
    # Items per sub-batch; the last sub-batch may be smaller.
    n_items = int(math.ceil(list_len / n_splits))

    batches = []
    for start in range(0, list_len, n_items):
        # Flatten nested dicts so every leaf gets collated on its own key.
        flat_dicts = [flatten(d) for d in list_of_dicts[start : start + n_items]]
        dict_of_lists = {
            key: [flat[key] for flat in flat_dicts] for key in flat_dicts[0]
        }

        batch = {}
        for key, values in dict_of_lists.items():
            if all(isinstance(v, AudioSignal) for v in values):
                batch[key] = AudioSignal.batch(values, pad_signals=True)
            else:
                batch[key] = torch.utils.data._utils.collate.default_collate(values)
        batches.append(unflatten(batch))

    return batches if return_list else batches[0]
BASE_SIZE = 864
# Figure size used when ``format_figure`` is called without ``fig_size``.
DEFAULT_FIG_SIZE = (9, 3)


def format_figure(
    fig_size: tuple = None,
    title: str = None,
    fig=None,
    format_axes: bool = True,
    format: bool = True,
    font_color: str = "white",
):
    """Prettifies the spectrogram and waveform plots. A title
    can be inset into the top right corner, and the axes can be
    inset into the figure, allowing the data to take up the entire
    image. Used in

    - :py:func:`audiotools.core.display.DisplayMixin.specshow`
    - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
    - :py:func:`audiotools.core.display.DisplayMixin.wavespec`

    Parameters
    ----------
    fig_size : tuple, optional
        Size of figure, by default (9, 3)
    title : str, optional
        Title to inset in top right, by default None
    fig : matplotlib.figure.Figure, optional
        Figure object, if None ``plt.gcf()`` will be used, by default None
    format_axes : bool, optional
        Format the axes to be inside the figure, by default True
    format : bool, optional
        This formatting can be skipped entirely by passing ``format=False``
        to any of the plotting functions that use this formatter, by default True
    font_color : str, optional
        Color of font of axes, by default "white"
    """
    if not format:
        # Nothing to do, so matplotlib is not imported at all.
        return

    import matplotlib.pyplot as plt

    if fig_size is None:
        fig_size = DEFAULT_FIG_SIZE
    if fig is None:
        fig = plt.gcf()

    fig.set_size_inches(*fig_size)
    axs = fig.axes

    # Scale fonts with the figure width in pixels relative to BASE_SIZE.
    pixels = (fig.get_size_inches() * fig.dpi)[0]
    font_scale = pixels / BASE_SIZE

    if format_axes:
        for ax in axs:
            ymin, _ = ax.get_ylim()
            xmin, _ = ax.get_xlim()

            # NOTE(review): tick labels for every axis are drawn on
            # ``axs[0]`` (kept from the original) -- confirm this is
            # intended for figures with more than one axis.
            # The first two ticks and the last one are not labelled.
            for tick in ax.get_yticks()[2:-1]:
                axs[0].annotate(
                    f"{(tick / 1000):2.1f}k",
                    xy=(xmin, tick),
                    xycoords="data",
                    xytext=(5, -5),
                    textcoords="offset points",
                    ha="left",
                    va="top",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            for tick in ax.get_xticks()[2:-1]:
                axs[0].annotate(
                    f"{tick:2.1f}s",
                    xy=(tick, ymin),
                    xycoords="data",
                    xytext=(5, 5),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    color=font_color,
                    fontsize=12 * font_scale,
                    alpha=0.75,
                )

            # Hide the regular axes so the data fills the whole image.
            ax.margins(0, 0)
            ax.set_axis_off()
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())

        # Adjust the figure being formatted; ``plt.subplots_adjust`` acts on
        # the current figure, which is not necessarily ``fig``.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

    if title is not None:
        label = axs[0].annotate(
            title,
            xy=(1, 1),
            xycoords="axes fraction",
            fontsize=20 * font_scale,
            xytext=(-5, -5),
            textcoords="offset points",
            ha="right",
            va="top",
            color="white",
        )
        label.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
def generate_chord_dataset(
    max_voices: int = 8,
    sample_rate: int = 44100,
    num_items: int = 5,
    duration: float = 1.0,
    min_note: str = "C2",
    max_note: str = "C6",
    output_dir: Path = "chords",
):
    """
    Generates a toy multitrack dataset of chords, synthesized from sine waves.

    Each item is written to ``output_dir/track_<idx>/voice_<n>.wav``, and
    one ``voice_<n>.csv`` per voice name is created in ``output_dir`` via
    ``create_csv``. A track that lacks a given voice contributes an empty
    path to that voice's CSV.

    Parameters
    ----------
    max_voices : int, optional
        Maximum number of voices in a chord, by default 8
    sample_rate : int, optional
        Sample rate of audio, by default 44100
    num_items : int, optional
        Number of items to generate, by default 5
    duration : float, optional
        Duration of each item, by default 1.0
    min_note : str, optional
        Minimum note in the dataset, by default "C2"
    max_note : str, optional
        Maximum note in the dataset, by default "C6"
    output_dir : Path, optional
        Directory to save the dataset, by default "chords". Missing
        parent directories are created.

    Returns
    -------
    Path
        The output directory.
    """
    import librosa

    from . import AudioSignal
    from ..data.preprocess import create_csv

    min_midi = librosa.note_to_midi(min_note)
    max_midi = librosa.note_to_midi(max_note)

    tracks = []
    for _ in range(num_items):
        track = {}
        # Number of voices in this track.
        num_voices = random.randint(1, max_voices)
        for voice_idx in range(num_voices):
            # Random pitch, and a length between 85% and 100% of ``duration``.
            midinote = random.randint(min_midi, max_midi)
            dur = random.uniform(0.85 * duration, duration)

            track[f"voice_{voice_idx}"] = AudioSignal.wave(
                frequency=librosa.midi_to_hz(midinote),
                duration=dur,
                sample_rate=sample_rate,
                shape="sine",
            )
        tracks.append(track)

    # Write every voice of every track to disk.
    output_dir = Path(output_dir)
    # ``parents=True`` so a nested ``output_dir`` does not fail.
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, track in enumerate(tracks):
        track_dir = output_dir / f"track_{idx}"
        track_dir.mkdir(exist_ok=True)
        for voice_name, sig in track.items():
            sig.write(track_dir / f"{voice_name}.wav")

    # Sorted so the CSVs are produced in a deterministic order.
    all_voices = sorted({name for track in tracks for name in track})
    for voice_name in all_voices:
        # One row per track; "" where the track lacks this voice.
        # ``path_to_file`` is presumably set by ``sig.write`` above --
        # TODO confirm.
        paths = [
            track[voice_name].path_to_file if voice_name in track else ""
            for track in tracks
        ]
        create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)

    return output_dir