# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.

# %% auto 0
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
           'readlines']

# %% ../nbs/D. Common dataset utilities.ipynb 1
import os
import random
import torch
import torchaudio
from pathlib import Path
import webdataset as wds
from contextlib import contextmanager
import torch.nn.functional as F

# %% ../nbs/D. Common dataset utilities.ipynb 2
def shard_glob(input):
    if isinstance(input, (Path, str)):
        # brace patterns like {000000..000099} are expanded without touching the filesystem
        if '{' in str(input):
            return wds.shardlists.expand_urls(str(input))
        path = Path(input)
        if path.is_dir():
            glob = '*.tar.gz'
        else:
            glob = path.name
            path = path.parent
        input = path.glob(glob)
    elif not isinstance(input, list):
        raise ValueError("input should be either a list or a path with an optional glob specifier")
    return [str(x) for x in input]
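
# A minimal usage sketch for shard_glob; the shard names below are hypothetical
# (the function is defined here but never called on import):
def _shard_glob_example():
    # brace ranges expand without touching the filesystem
    urls = shard_glob('librilight-audio-{000000..000009}.tar.gz')
    # directories are globbed for '*.tar.gz'; explicit glob patterns also work
    local = shard_glob('./data/*.tar.gz')
    return urls, local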

# %% ../nbs/D. Common dataset utilities.ipynb 3
class join_datasets(torch.utils.data.IterableDataset):
    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        # draw each sample from one of the datasets, weighted by their optional `weight` attribute
        probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
        its = [iter(ds) for ds in self.datasets]
        while True:
            try:
                yield next(its[torch.multinomial(probs, 1)])
            except StopIteration:
                # stop as soon as any of the source datasets runs out
                return

    def __len__(self):
        # requires every dataset to expose a `total_samples` attribute
        return sum([ds.total_samples for ds in self.datasets])
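
# A usage sketch for join_datasets, assuming two iterable datasets; `weight` and
# `total_samples` are plain attributes this helper reads off each dataset:
def _join_datasets_example(librilight, commonvoice):
    librilight.weight = 3    # drawn ~3x as often as commonvoice
    commonvoice.weight = 1
    return join_datasets([librilight, commonvoice])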

# %% ../nbs/D. Common dataset utilities.ipynb 5
def resampler(newsr = 24000, key = 'samples_24k'):
    _last_sr = None
    tform = None

    def _resample(samples):
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                # rebuild the Resample transform only when the input sample rate changes
                if sr != _last_sr:
                    _last_sr = sr
                    tform = torchaudio.transforms.Resample(sr, newsr)
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s

    return _resample
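
# A minimal sketch of the resampler stage; it expects dicts with 'samples' and
# 'sample_rate' keys, like the ones split_to_chunks below produces:
def _resampler_example():
    batch = [{'samples': torch.zeros(16000), 'sample_rate': 16000}]
    stage = resampler(newsr=24000, key='samples_24k')
    return list(stage(batch))    # every dict gains a 24 kHz 'samples_24k' tensor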

# %% ../nbs/D. Common dataset utilities.ipynb 6
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
    dir = Path(dir) if dir else Path(input).parent
    return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
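
# Example (hypothetical file name):
#   derived_name('data/librilight-audio-000.tar', 'vad') == 'data/librilight-vad-000.tar.gz'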

# %% ../nbs/D. Common dataset utilities.ipynb 7
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
    def deriver(url):
        url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
        return wds.WebDataset(
            wds.SimpleShardList([url])
        ).decode(*decoders)
    return deriver
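
# A usage sketch (hypothetical shard name): open the 'vad' shard stored next to a given
# audio shard; webdataset's default handlers already decode .npy entries:
def _derived_dataset_example():
    deriver = derived_dataset('vad')
    return deriver('data/librilight-audio-000.tar')    # reads data/librilight-vad-000.tar.gz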

# %% ../nbs/D. Common dataset utilities.ipynb 8
def merge_in(dataset_fun):
    """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
    that takes a URL of a sample and returns a dataset for it (called every time the URL changes).

    It requires (and validates) that both datasets have the same ordering of keys so you have
    to use it before any sample shuffling. Shard shuffling is ok.
    """
    def merge_loop(main_samples):
        merged_samples = None
        cur_url = None
        for s in main_samples:
            url = s['__url__']
            if url != cur_url:
                # this will open a new file when we get the first sample with a new __url__
                merged_samples = iter(dataset_fun(url))
                cur_url = url
            try:
                merge_s = next(merged_samples)
            except StopIteration:
                # if the original shard got repeated we won't observe a __url__ change
                # in this case restart the dataset from the beginning
                merged_samples = iter(dataset_fun(url))
                merge_s = next(merged_samples)
            assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
            news = {}
            news.update(merge_s)
            news.update(s)
            yield news
    return merge_loop
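
# A usage sketch (hypothetical shards): attach the VAD timestamps stored in sibling
# '-vad-' shards to every audio sample; vad_dataset below composes this same stage:
def _merge_in_example(shards):
    return wds.WebDataset(shards).compose(
        wds.decode(wds.torch_audio),
        merge_in(derived_dataset('vad')),
    )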

# %% ../nbs/D. Common dataset utilities.ipynb 9
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
    for s in stream:
        audio, sr = s['audio']
        imax = len(s[ikey]) - 1
        for i,(ts,te) in enumerate(s[ikey]):
            samples = audio[0,int(ts*sr):int(te*sr)]
            lpad = padding = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr-samples.shape[-1]
                # a negative `padding` (chunk longer than pad_to_seconds) makes F.pad truncate the end
                lpad = random.randint(0, max(padding, 0)) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            subs = {"__key__": s['__key__'] + f"_{i:03d}",
                    "src_key": s['__key__'],
                    "__url__": s['__url__'],
                    "i": i, "imax": imax,
                    "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                    "lpad": lpad, "rpad": padding-lpad,
                    "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                    "samples": samples, "sample_rate": sr}
            for k in metakeys:
                subs[k] = s[k][i]
            yield subs
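
# A minimal sketch of what split_to_chunks emits, using an in-memory sample with two
# hypothetical VAD spans (the chunk keys become 'book0_chap1_000' and 'book0_chap1_001'):
def _split_to_chunks_example():
    import numpy as np
    sample = {'__key__': 'book0_chap1', '__url__': 'mem://example',
              'audio': (torch.zeros(1, 32000), 16000),
              'vad.npy': np.array([[0.0, 1.0], [1.0, 2.0]])}
    return list(split_to_chunks([sample]))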

# %% ../nbs/D. Common dataset utilities.ipynb 10
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
    return wds.WebDataset(shards).compose(
        wds.decode(wds.torch_audio),
        merge_in(derived_dataset(kind)),
        wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
        wds.rename(audio="flac;mp3;wav;ogg"),
        lambda x: split_to_chunks(x, ikey=ikey),
    )
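
# A usage sketch (hypothetical shard names): a full pipeline from audio shards to
# padded 30-second speech chunks:
def _vad_dataset_example():
    ds = vad_dataset(shard_glob('data/librilight-audio-*.tar.gz'))
    return next(iter(ds))    # a dict with 'samples', 'sample_rate', 'tstart', 'tend', ...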

# %% ../nbs/D. Common dataset utilities.ipynb 11
@contextmanager
def AtomicTarWriter(name, throwaway=False):
    tmp = name+".tmp"
    with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
        yield sink
    if not throwaway:
        # only move the finished shard into place, so a crash never leaves a truncated file under `name`
        os.rename(tmp, name)
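
# A usage sketch (hypothetical output name): the shard is written to a '.tmp' file and
# only renamed into place after the block exits, so readers never see partial output:
def _atomic_tar_writer_example(samples):
    with AtomicTarWriter('data/librilight-vad-000.tar.gz') as sink:
        for s in samples:
            sink.write(s)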

# %% ../nbs/D. Common dataset utilities.ipynb 12
def readlines(fname):
    with open(fname) as file:
        return [line.rstrip() for line in file]