""" |
|
Helper script to pre-compute wav2vec embeddings for a flashlight (previously called wav2letter++) dataset.
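
Example invocation (script name and paths below are illustrative):

    python wav2vec_featurize.py --input /path/to/dataset --output /path/to/features \
        --model /path/to/wav2vec_checkpoint.pt --split train valid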
|
""" |
|
|
|
import argparse |
|
import glob |
|
import os |
|
from shutil import copy |
|
|
|
import h5py |
|
import numpy as np |
|
import soundfile as sf |
|
import torch |
|
import tqdm |
|
import fairseq |
|
from torch import nn |
|
|
|
|
|
def read_audio(fname): |
|
""" Load an audio file and return PCM along with the sample rate """ |
|
|
|
wav, sr = sf.read(fname) |
|
    assert sr == 16e3, "expected 16 kHz audio, got sample rate {}".format(sr)
|
|
|
return wav, 16e3 |
|
|
|
|
|
class PretrainedWav2VecModel(nn.Module): |
|
def __init__(self, fname): |
|
super().__init__() |
|
|
|
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname]) |
|
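        # load_model_ensemble_and_task returns a list of models; only the first one is used here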
model = model[0] |
|
model.eval() |
|
|
|
self.model = model |
|
|
|
def forward(self, x): |
|
with torch.no_grad(): |
|
z = self.model.feature_extractor(x) |
|
if isinstance(z, tuple): |
|
z = z[0] |
|
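            # z: local latent features from the convolutional encoder;
            # the context network (feature_aggregator) turns them into context vectors c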
c = self.model.feature_aggregator(z) |
|
return z, c |
|
|
|
|
|
class EmbeddingWriterConfig(argparse.ArgumentParser): |
|
def __init__(self): |
|
super().__init__("Pre-compute embeddings for flashlight datasets") |
|
|
|
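        # shared settings for the required string arguments defined below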
kwargs = {"action": "store", "type": str, "required": True} |
|
|
|
self.add_argument("--input", "-i", help="Input Directory", **kwargs) |
|
self.add_argument("--output", "-o", help="Output Directory", **kwargs) |
|
self.add_argument("--model", help="Path to model checkpoint", **kwargs) |
|
self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs) |
|
self.add_argument( |
|
"--ext", default="wav", required=False, help="Audio file extension" |
|
) |
|
|
|
self.add_argument( |
|
"--no-copy-labels", |
|
action="store_true", |
|
            help="Do not copy label files. Useful for large datasets; use flashlight's --targetdir option instead.",
|
) |
|
self.add_argument( |
|
"--use-feat", |
|
action="store_true", |
|
help="Use the feature vector ('z') instead of context vector ('c') for features", |
|
) |
|
self.add_argument("--gpu", help="GPU to use", default=0, type=int) |
|
|
|
|
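# Example use of the Prediction wrapper (checkpoint path below is illustrative):
#
#     embedder = Prediction("/path/to/wav2vec.pt", gpu=0)
#     z, c = embedder(wav)  # wav: 1-D float numpy array sampled at 16 kHz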
|
class Prediction:
    """ Lightweight wrapper around a fairseq wav2vec embedding model """
|
|
|
def __init__(self, fname, gpu=0): |
|
self.gpu = gpu |
|
self.model = PretrainedWav2VecModel(fname).cuda(gpu) |
|
|
|
def __call__(self, x): |
|
x = torch.from_numpy(x).float().cuda(self.gpu) |
|
with torch.no_grad(): |
|
z, c = self.model(x.unsqueeze(0)) |
|
|
|
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy() |
|
|
|
|
|
class H5Writer: |
|
""" Write features as hdf5 file in flashlight compatible format """ |
|
|
|
def __init__(self, fname): |
|
self.fname = fname |
|
os.makedirs(os.path.dirname(self.fname), exist_ok=True) |
|
|
|
def write(self, data): |
|
channel, T = data.shape |
|
|
|
with h5py.File(self.fname, "w") as out_ds: |
|
data = data.T.flatten() |
|
out_ds["features"] = data |
|
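            # "info" layout: [frames per second, time steps, channels];
            # 16000 // 160 = 100 matches wav2vec's downsampling factor of 160 samples per frame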
out_ds["info"] = np.array([16e3 // 160, T, channel]) |
|
|
|
|
|
class EmbeddingDatasetWriter(object):
    """Given a model and a flashlight dataset, pre-compute and store embeddings

    Args:
        input_root, str :
            Path to the flashlight dataset
        output_root, str :
            Desired output directory. Will be created if non-existent
        split, str :
            Dataset split
        model_fname, str :
            Path to the wav2vec model checkpoint
        extension, str :
            Audio file extension (default: "wav")
        gpu, int :
            GPU device the model runs on (default: 0)
        verbose, bool :
            Show a progress bar while writing features (default: False)
        use_feat, bool :
            Store feature vectors ('z') instead of context vectors ('c') (default: False)
    """
|
|
|
def __init__( |
|
self, |
|
input_root, |
|
output_root, |
|
split, |
|
model_fname, |
|
extension="wav", |
|
gpu=0, |
|
verbose=False, |
|
use_feat=False, |
|
): |
|
|
|
assert os.path.exists(model_fname) |
|
|
|
self.model_fname = model_fname |
|
self.model = Prediction(self.model_fname, gpu) |
|
|
|
self.input_root = input_root |
|
self.output_root = output_root |
|
self.split = split |
|
self.verbose = verbose |
|
self.extension = extension |
|
self.use_feat = use_feat |
|
|
|
assert os.path.exists(self.input_path), "Input path '{}' does not exist".format( |
|
self.input_path |
|
) |
|
|
|
def _progress(self, iterable, **kwargs): |
|
if self.verbose: |
|
return tqdm.tqdm(iterable, **kwargs) |
|
return iterable |
|
|
|
def require_output_path(self, fname=None): |
|
path = self.get_output_path(fname) |
|
os.makedirs(path, exist_ok=True) |
|
|
|
@property |
|
def input_path(self): |
|
return self.get_input_path() |
|
|
|
@property |
|
def output_path(self): |
|
return self.get_output_path() |
|
|
|
def get_input_path(self, fname=None): |
|
if fname is None: |
|
return os.path.join(self.input_root, self.split) |
|
return os.path.join(self.get_input_path(), fname) |
|
|
|
def get_output_path(self, fname=None): |
|
if fname is None: |
|
return os.path.join(self.output_root, self.split) |
|
return os.path.join(self.get_output_path(), fname) |
|
|
|
def copy_labels(self): |
|
self.require_output_path() |
|
|
|
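        # every file in the split directory that does not contain the audio
        # extension in its name is treated as a label file and copied over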
labels = list( |
|
filter( |
|
lambda x: self.extension not in x, glob.glob(self.get_input_path("*")) |
|
) |
|
) |
|
for fname in tqdm.tqdm(labels): |
|
copy(fname, self.output_path) |
|
|
|
@property |
|
def input_fnames(self): |
|
return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension)))) |
|
|
|
def __len__(self): |
|
return len(self.input_fnames) |
|
|
|
def write_features(self): |
|
|
|
paths = self.input_fnames |
|
|
|
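        # derive one output name per input file: <output_root>/<split>/<basename>.h5context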
fnames_context = map( |
|
lambda x: os.path.join( |
|
self.output_path, x.replace("." + self.extension, ".h5context") |
|
), |
|
map(os.path.basename, paths), |
|
) |
|
|
|
for name, target_fname in self._progress( |
|
zip(paths, fnames_context), total=len(self) |
|
): |
|
wav, sr = read_audio(name) |
|
z, c = self.model(wav) |
|
feat = z if self.use_feat else c |
|
writer = H5Writer(target_fname) |
|
writer.write(feat) |
|
|
|
def __repr__(self): |
|
|
|
        return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split}".format(
|
n_files=len(self), **self.__dict__ |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
args = EmbeddingWriterConfig().parse_args() |
|
|
|
for split in args.split: |
|
|
|
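        # one writer per split: reads audio from <input>/<split>, writes .h5context files to <output>/<split>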
writer = EmbeddingDatasetWriter( |
|
input_root=args.input, |
|
output_root=args.output, |
|
split=split, |
|
model_fname=args.model, |
|
gpu=args.gpu, |
|
extension=args.ext, |
|
use_feat=args.use_feat, |
|
) |
|
|
|
print(writer) |
|
writer.require_output_path() |
|
|
|
print("Writing Features...") |
|
writer.write_features() |
|
print("Done.") |
|
|
|
if not args.no_copy_labels: |
|
print("Copying label data...") |
|
writer.copy_labels() |
|
print("Done.") |
|
|