# File size: 4,014 Bytes
# 9791162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys,os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import argparse

from tqdm import tqdm
from functools import partial
from argparse import RawTextHelpFormatter
from multiprocessing.pool import ThreadPool

from speaker.models.lstm import LSTMSpeakerEncoder
from speaker.config import SpeakerEncoderConfig
from speaker.utils.audio import AudioProcessor
from speaker.infer import read_json


def get_spk_wavs(dataset_path, output_path):
    """Collect the wav files of a dataset and prepare the output directories.

    ``dataset_path`` is scanned one level deep: ``.wav`` files directly inside
    it are collected, and so are ``.wav`` files inside each of its immediate
    sub-directories. For every such sub-directory a directory with the same
    name is created under ``output_path``.

    Paths are built from the arguments exactly as given (no ``./`` prefix),
    so absolute and relative paths both work.

    Args:
        dataset_path: Directory to scan for wav files.
        output_path: Directory under which the per-speaker output
            directories are created (created itself if missing).

    Returns:
        A list of wav file paths, each starting with ``dataset_path``.
    """
    wav_files = []
    os.makedirs(output_path, exist_ok=True)
    for entry in os.listdir(dataset_path):
        # "/" is used (rather than os.path.join) so that every returned path
        # still starts with dataset_path followed by a separator, even when
        # dataset_path carries a trailing slash.
        entry_path = f"{dataset_path}/{entry}"
        if os.path.isdir(entry_path):
            os.makedirs(f"{output_path}/{entry}", exist_ok=True)
            wav_files.extend(
                f"{entry_path}/{name}"
                for name in os.listdir(entry_path)
                if name.endswith(".wav")
            )
        elif entry.endswith(".wav"):
            wav_files.append(entry_path)
    return wav_files


def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):
    """Compute the speaker embedding of one wav file and save it to disk.

    The output file mirrors the input location: the first occurrence of
    ``dataset_path`` in ``wav_file`` is replaced by ``output_path`` and a
    trailing ``.wav`` is replaced by ``.spk``. ``np.save`` then appends
    ``.npy``, so ``<dataset>/spk/a.wav`` is stored as
    ``<output>/spk/a.spk.npy``. The target directory must already exist.

    Args:
        wav_file: Path of the wav file; expected to contain ``dataset_path``.
        dataset_path: Dataset root, as it appears inside ``wav_file``.
        output_path: Root directory for the saved embeddings.
        args: Object with a boolean ``use_cuda`` attribute.
        speaker_encoder_ap: Audio processor providing ``sample_rate``,
            ``load_wav(path, sr=...)`` and ``melspectrogram(waveform)``.
        speaker_encoder: Model providing ``compute_embedding(spec)``.

    Returns:
        None. The embedding is written to disk as a side effect.
    """
    waveform = speaker_encoder_ap.load_wav(
        wav_file, sr=speaker_encoder_ap.sample_rate
    )
    # The spectrogram is transposed before batching; presumably this turns
    # (n_mels, frames) into (frames, n_mels) -- confirm against AudioProcessor.
    spec = torch.from_numpy(speaker_encoder_ap.melspectrogram(waveform).T)
    if args.use_cuda:
        spec = spec.cuda()
    spec = spec.unsqueeze(0)  # add a leading batch dimension

    # Grad mode is thread-local, so it has to be disabled here, inside the
    # worker thread, to avoid building an autograd graph during inference.
    with torch.no_grad():
        embed = speaker_encoder.compute_embedding(spec)
    embed = embed.detach().cpu().numpy().squeeze()

    # Replace only the first occurrence: a speaker or file name that happens
    # to contain the dataset path text (e.g. "data" in "metadata") must not
    # be rewritten, otherwise the target directory does not exist.
    embed_path = wav_file.replace(dataset_path, output_path, 1)
    # Likewise swap only the trailing extension, not every ".wav" in the path.
    wav_ext = ".wav"
    if embed_path.endswith(wav_ext):
        embed_path = embed_path[: -len(wav_ext)] + ".spk"
    np.save(embed_path, embed, allow_pickle=False)


def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):
    """Run ``process_wav`` over every wav file using a pool of threads.

    Args:
        wav_files: Sized collection of wav file paths to process.
        dataset_path: Forwarded to ``process_wav``.
        output_path: Forwarded to ``process_wav``.
        args: Forwarded to ``process_wav``.
        speaker_encoder_ap: Forwarded to ``process_wav``.
        speaker_encoder: Forwarded to ``process_wav``.
        concurrency: Number of worker threads handed to ``ThreadPool``.

    Returns:
        None. Progress is reported through a ``tqdm`` bar.
    """

    def _embed_one(path):
        # Same keyword binding the shared arguments would get via partial().
        return process_wav(
            path,
            dataset_path=dataset_path,
            output_path=output_path,
            args=args,
            speaker_encoder_ap=speaker_encoder_ap,
            speaker_encoder=speaker_encoder,
        )

    with ThreadPool(concurrency) as workers:
        results = workers.imap(_embed_one, wav_files)
        # Drain the lazy iterator so every file is processed and the bar
        # advances once per finished file.
        for _ in tqdm(results, total=len(wav_files)):
            pass


if __name__ == "__main__":
    # Command-line entry point: load the pretrained speaker encoder and write
    # one embedding file per wav found under the dataset directory.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would treat every non-empty string (including
        "False") as True, which made it impossible to turn CUDA off.
        """
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1", "on"):
            return True
        # The empty string stays False, as it was under type=bool.
        if text in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
    parser.add_argument(
        "output_path", type=str, help="path for output speaker/speaker_wavs.npy."
    )
    parser.add_argument(
        "--use_cuda",
        type=_str2bool,
        help="flag to set cuda (true/false).",
        default=True,
    )
    parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
    args = parser.parse_args()
    dataset_path = args.dataset_path
    output_path = args.output_path
    thread_count = args.thread_count
    # Fail before the model is loaded instead of later inside ThreadPool.
    if thread_count < 0:
        parser.error("thread_count must be 0 (all cpu cores) or a positive integer")

    # Fixed locations of the pretrained checkpoint and its config.
    args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
    args.config_path = os.path.join("speaker_pretrain", "config.json")

    # Build the encoder config from the JSON file.
    config_dict = read_json(args.config_path)
    config = SpeakerEncoderConfig(config_dict)
    config.from_dict(config_dict)

    # Instantiate the model and load the checkpoint weights.
    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)

    # Audio preprocessing: normalize the input level and trim silences.
    speaker_encoder_ap = AudioProcessor(**config.audio)
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    wav_files = get_spk_wavs(dataset_path, output_path)

    # A thread count of 0 means "use every CPU core".
    process_num = os.cpu_count() if thread_count == 0 else thread_count

    extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num)