# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse from collections import defaultdict from itertools import chain from pathlib import Path import numpy as np import torchaudio import torchaudio.sox_effects as ta_sox import yaml from tqdm import tqdm from examples.speech_to_text.data_utils import load_tsv_to_dicts from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder def extract_embedding(audio_path, embedder): wav, sr = torchaudio.load(audio_path) # 2D if sr != embedder.RATE: wav, sr = ta_sox.apply_effects_tensor( wav, sr, [["rate", str(embedder.RATE)]] ) try: emb = embedder([wav[0].cuda().float()]).cpu().numpy() except RuntimeError: emb = None return emb def process(args): print("Fetching data...") raw_manifest_root = Path(args.raw_manifest_root).absolute() samples = [load_tsv_to_dicts(raw_manifest_root / (s + ".tsv")) for s in args.splits] samples = list(chain(*samples)) with open(args.config, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f: speaker_to_id = {r.strip(): i for i, r in enumerate(f)} embedder = SpkrEmbedder(args.ckpt).cuda() speaker_to_cnt = defaultdict(float) speaker_to_emb = defaultdict(float) for sample in tqdm(samples, desc="extract emb"): emb = extract_embedding(sample["audio"], embedder) if emb is not None: speaker_to_cnt[sample["speaker"]] += 1 speaker_to_emb[sample["speaker"]] += emb if len(speaker_to_emb) != len(speaker_to_id): missed = set(speaker_to_id) - set(speaker_to_emb.keys()) print( f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}" ) speaker_emb_mat = np.zeros((len(speaker_to_id), len(emb)), float) for speaker in speaker_to_emb: idx = speaker_to_id[speaker] emb = speaker_to_emb[speaker] cnt = speaker_to_cnt[speaker] speaker_emb_mat[idx, :] = emb / cnt speaker_emb_name = "speaker_emb.npy" speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}" np.save(speaker_emb_path, speaker_emb_mat) config["speaker_emb_filename"] = speaker_emb_name with open(args.new_config, "w") as f: yaml.dump(config, f) def main(): parser = argparse.ArgumentParser() parser.add_argument("--raw-manifest-root", "-m", required=True, type=str) parser.add_argument("--splits", "-s", type=str, nargs="+", default=["train"]) parser.add_argument("--config", "-c", required=True, type=str) parser.add_argument("--new-config", "-n", required=True, type=str) parser.add_argument("--ckpt", required=True, type=str, help="speaker embedder checkpoint") args = parser.parse_args() process(args) if __name__ == "__main__": main()