Yeserumo committed
Commit
ec22b4d
1 Parent(s): 45d33f7

add encoder

encoder/__init__.py ADDED
File without changes
encoder/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (159 Bytes).
encoder/__pycache__/audio.cpython-37.pyc ADDED
Binary file (3.97 kB).
encoder/__pycache__/inference.cpython-37.pyc ADDED
Binary file (7.17 kB).
encoder/__pycache__/model.cpython-37.pyc ADDED
Binary file (4.77 kB).
encoder/__pycache__/params_data.cpython-37.pyc ADDED
Binary file (466 Bytes).
encoder/__pycache__/params_model.cpython-37.pyc ADDED
Binary file (346 Bytes).
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
+ from scipy.ndimage.morphology import binary_dilation
+ from encoder.params_data import *
+ from pathlib import Path
+ from typing import Optional, Union
+ from warnings import warn
+ import numpy as np
+ import librosa
+ import struct
+
+ try:
+     import webrtcvad
+ except:
+     warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+     webrtcvad = None
+
+ int16_max = (2 ** 15) - 1
+
+
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                    source_sr: Optional[int] = None,
+                    normalize: Optional[bool] = True,
+                    trim_silence: Optional[bool] = True):
+     """
+     Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+     either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+     :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+     just .wav), or the waveform as a numpy array of floats.
+     :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+     preprocessing. After preprocessing, the waveform's sampling rate will match the data
+     hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+     this argument will be ignored.
+     """
+     # Load the wav from disk if needed
+     if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+         wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+     else:
+         wav = fpath_or_wav
+
+     # Resample the wav if needed
+     if source_sr is not None and source_sr != sampling_rate:
+         wav = librosa.resample(wav, source_sr, sampling_rate)
+
+     # Apply the preprocessing: normalize volume and shorten long silences
+     if normalize:
+         wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+     if webrtcvad and trim_silence:
+         wav = trim_long_silences(wav)
+
+     return wav
+
+
+ def wav_to_mel_spectrogram(wav):
+     """
+     Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+     Note: this is not a log-mel spectrogram.
+     """
+     frames = librosa.feature.melspectrogram(
+         wav,
+         sampling_rate,
+         n_fft=int(sampling_rate * mel_window_length / 1000),
+         hop_length=int(sampling_rate * mel_window_step / 1000),
+         n_mels=mel_n_channels
+     )
+     return frames.astype(np.float32).T
+
+
+ def trim_long_silences(wav):
+     """
+     Ensures that segments without voice in the waveform remain no longer than a
+     threshold determined by the VAD parameters in params.py.
+
+     :param wav: the raw waveform as a numpy array of floats
+     :return: the same waveform with silences trimmed away (length <= original wav length)
+     """
+     # Compute the voice detection window size
+     samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+     # Trim the end of the audio to have a multiple of the window size
+     wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+     # Convert the float waveform to 16-bit mono PCM
+     pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+     # Perform voice activity detection
+     voice_flags = []
+     vad = webrtcvad.Vad(mode=3)
+     for window_start in range(0, len(wav), samples_per_window):
+         window_end = window_start + samples_per_window
+         voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                          sample_rate=sampling_rate))
+     voice_flags = np.array(voice_flags)
+
+     # Smooth the voice detection with a moving average
+     def moving_average(array, width):
+         array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+         ret = np.cumsum(array_padded, dtype=float)
+         ret[width:] = ret[width:] - ret[:-width]
+         return ret[width - 1:] / width
+
+     audio_mask = moving_average(voice_flags, vad_moving_average_width)
+     audio_mask = np.round(audio_mask).astype(np.bool)
+
+     # Dilate the voiced regions
+     audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+     audio_mask = np.repeat(audio_mask, samples_per_window)
+
+     return wav[audio_mask == True]
+
+
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+     if increase_only and decrease_only:
+         raise ValueError("Both increase only and decrease only are set")
+     dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+     if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+         return wav
+     return wav * (10 ** (dBFS_change / 20))
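
For reference, a minimal usage sketch of this module. The file name "speech.wav" is a hypothetical placeholder, and this assumes the older pinned dependencies the code targets (a librosa release that still accepts positional arguments to melspectrogram, and a NumPy with the np.bool alias):

import numpy as np
from encoder import audio
from encoder.params_data import sampling_rate

# "speech.wav" is a placeholder path; any format librosa can decode works
wav = audio.preprocess_wav("speech.wav")   # waveform at 16 kHz, volume-normalised, silences trimmed if webrtcvad is installed
mel = audio.wav_to_mel_spectrogram(wav)    # shape (n_frames, mel_n_channels) = (n_frames, 40)
print(len(wav) / sampling_rate, mel.shape)
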
encoder/config.py ADDED
@@ -0,0 +1,45 @@
+ librispeech_datasets = {
+     "train": {
+         "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+         "other": ["LibriSpeech/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriSpeech/test-clean"],
+         "other": ["LibriSpeech/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriSpeech/dev-clean"],
+         "other": ["LibriSpeech/dev-other"]
+     },
+ }
+ libritts_datasets = {
+     "train": {
+         "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+         "other": ["LibriTTS/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriTTS/test-clean"],
+         "other": ["LibriTTS/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriTTS/dev-clean"],
+         "other": ["LibriTTS/dev-other"]
+     },
+ }
+ voxceleb_datasets = {
+     "voxceleb1": {
+         "train": ["VoxCeleb1/wav"],
+         "test": ["VoxCeleb1/test_wav"]
+     },
+     "voxceleb2": {
+         "train": ["VoxCeleb2/dev/aac"],
+         "test": ["VoxCeleb2/test_wav"]
+     }
+ }
+
+ other_datasets = [
+     "LJSpeech-1.1",
+     "VCTK-Corpus/wav48",
+ ]
+
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
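
These tables only hold dataset-relative paths; encoder/preprocess.py joins them onto a user-supplied datasets root. A small sketch, with a hypothetical "/data" root:

from pathlib import Path
from encoder.config import librispeech_datasets, voxceleb_datasets

datasets_root = Path("/data")   # placeholder location of the raw corpora
for rel_path in librispeech_datasets["train"]["other"] + voxceleb_datasets["voxceleb1"]["train"]:
    print(datasets_root.joinpath(rel_path))   # /data/LibriSpeech/train-other-500, /data/VoxCeleb1/wav
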
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
+ import random
+
+ class RandomCycler:
+     """
+     Creates an internal copy of a sequence and allows access to its items in a constrained random
+     order. For a source sequence of n items and one or several consecutive queries of a total
+     of m items, the following guarantees hold (one implies the other):
+         - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+         - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+     """
+
+     def __init__(self, source):
+         if len(source) == 0:
+             raise Exception("Can't create RandomCycler from an empty collection")
+         self.all_items = list(source)
+         self.next_items = []
+
+     def sample(self, count: int):
+         shuffle = lambda l: random.sample(l, len(l))
+
+         out = []
+         while count > 0:
+             if count >= len(self.all_items):
+                 out.extend(shuffle(list(self.all_items)))
+                 count -= len(self.all_items)
+                 continue
+             n = min(count, len(self.next_items))
+             out.extend(self.next_items[:n])
+             count -= n
+             self.next_items = self.next_items[n:]
+             if len(self.next_items) == 0:
+                 self.next_items = shuffle(list(self.all_items))
+         return out
+
+     def __next__(self):
+         return self.sample(1)[0]
+
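
A quick illustration of the sampling guarantee in the docstring, with n = 3 items and m = 7 draws, so every item must come back between 7 // 3 = 2 and ((7 - 1) // 3) + 1 = 3 times:

from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
draws = cycler.sample(7)
assert sorted(draws.count(item) for item in "abc") == [2, 2, 3]
print(draws)    # e.g. ['b', 'a', 'c', 'c', 'a', 'b', 'a']
next(cycler)    # single draws keep cycling through the same internal copy
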
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.utterance import Utterance
+ from pathlib import Path
+
+ # Contains the set of utterances of a single speaker
+ class Speaker:
+     def __init__(self, root: Path):
+         self.root = root
+         self.name = root.name
+         self.utterances = None
+         self.utterance_cycler = None
+
+     def _load_utterances(self):
+         with self.root.joinpath("_sources.txt").open("r") as sources_file:
+             sources = [l.split(",") for l in sources_file]
+         sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+         self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+         self.utterance_cycler = RandomCycler(self.utterances)
+
+     def random_partial(self, count, n_frames):
+         """
+         Samples a batch of <count> unique partial utterances from the disk in a way that all
+         utterances come up at least once every two cycles and in a random order every time.
+
+         :param count: The number of partial utterances to sample from the set of utterances from
+         that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
+         the number of utterances available.
+         :param n_frames: The number of frames in the partial utterance.
+         :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+         frames are the frames of the partial utterances and range is the range of the partial
+         utterance with regard to the complete utterance.
+         """
+         if self.utterances is None:
+             self._load_utterances()
+
+         utterances = self.utterance_cycler.sample(count)
+
+         a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+         return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
+ import numpy as np
+ from typing import List
+ from encoder.data_objects.speaker import Speaker
+
+
+ class SpeakerBatch:
+     def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+         self.speakers = speakers
+         self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+         # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+         # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+         self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.speaker_batch import SpeakerBatch
+ from encoder.data_objects.speaker import Speaker
+ from encoder.params_data import partials_n_frames
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+
+ # TODO: improve with a pool of speakers for data efficiency
+
+ class SpeakerVerificationDataset(Dataset):
+     def __init__(self, datasets_root: Path):
+         self.root = datasets_root
+         speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+         if len(speaker_dirs) == 0:
+             raise Exception("No speakers found. Make sure you are pointing to the directory "
+                             "containing all preprocessed speaker directories.")
+         self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+         self.speaker_cycler = RandomCycler(self.speakers)
+
+     def __len__(self):
+         return int(1e10)
+
+     def __getitem__(self, index):
+         return next(self.speaker_cycler)
+
+     def get_logs(self):
+         log_string = ""
+         for log_fpath in self.root.glob("*.txt"):
+             with log_fpath.open("r") as log_file:
+                 log_string += "".join(log_file.readlines())
+         return log_string
+
+
+ class SpeakerVerificationDataLoader(DataLoader):
+     def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+                  batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+                  worker_init_fn=None):
+         self.utterances_per_speaker = utterances_per_speaker
+
+         super().__init__(
+             dataset=dataset,
+             batch_size=speakers_per_batch,
+             shuffle=False,
+             sampler=sampler,
+             batch_sampler=batch_sampler,
+             num_workers=num_workers,
+             collate_fn=self.collate,
+             pin_memory=pin_memory,
+             drop_last=False,
+             timeout=timeout,
+             worker_init_fn=worker_init_fn
+         )
+
+     def collate(self, speakers):
+         return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
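
A minimal sketch of how the dataset and loader are meant to be driven. The directory name is a hypothetical placeholder and must already contain the per-speaker folders written by encoder/preprocess.py; note that __len__ is effectively infinite, so iteration never stops on its own:

from pathlib import Path
from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

# "SV2TTS/encoder" is a placeholder for the output directory of encoder/preprocess.py
dataset = SpeakerVerificationDataset(Path("SV2TTS/encoder"))
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5)

speaker_batch = next(iter(loader))
# collate() built a SpeakerBatch: (4 speakers * 5 utterances, 160 frames, 40 mel channels)
print(speaker_batch.data.shape)   # (20, 160, 40)
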
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+
+
+ class Utterance:
+     def __init__(self, frames_fpath, wave_fpath):
+         self.frames_fpath = frames_fpath
+         self.wave_fpath = wave_fpath
+
+     def get_frames(self):
+         return np.load(self.frames_fpath)
+
+     def random_partial(self, n_frames):
+         """
+         Crops the frames into a partial utterance of n_frames
+
+         :param n_frames: The number of frames of the partial utterance
+         :return: the partial utterance frames and a tuple indicating the start and end of the
+         partial utterance in the complete utterance.
+         """
+         frames = self.get_frames()
+         if frames.shape[0] == n_frames:
+             start = 0
+         else:
+             start = np.random.randint(0, frames.shape[0] - n_frames)
+         end = start + n_frames
+         return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
+ from encoder.params_data import *
+ from encoder.model import SpeakerEncoder
+ from encoder.audio import preprocess_wav   # We want to expose this function from here
+ from matplotlib import cm
+ from encoder import audio
+ from pathlib import Path
+ import numpy as np
+ import torch
+
+ _model = None   # type: SpeakerEncoder
+ _device = None  # type: torch.device
+
+
+ def load_model(weights_fpath: Path, device=None):
+     """
+     Loads the model in memory. If this function is not explicitly called, it will be run on the
+     first call to embed_frames() with the default weights file.
+
+     :param weights_fpath: the path to saved model weights.
+     :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+     model will be loaded and will run on this device. Outputs will however always be on the cpu.
+     If None, will default to your GPU if it's available, otherwise your CPU.
+     """
+     # TODO: I think the slow loading of the encoder might have something to do with the device it
+     # was saved on. Worth investigating.
+     global _model, _device
+     if device is None:
+         _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     elif isinstance(device, str):
+         _device = torch.device(device)
+     _model = SpeakerEncoder(_device, torch.device("cpu"))
+     checkpoint = torch.load(weights_fpath, _device)
+     _model.load_state_dict(checkpoint["model_state"])
+     _model.eval()
+     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+
+
+ def is_loaded():
+     return _model is not None
+
+
+ def embed_frames_batch(frames_batch):
+     """
+     Computes embeddings for a batch of mel spectrograms.
+
+     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
+     (batch_size, n_frames, n_channels)
+     :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+     """
+     if _model is None:
+         raise Exception("Model was not loaded. Call load_model() before inference.")
+
+     frames = torch.from_numpy(frames_batch).to(_device)
+     embed = _model.forward(frames).detach().cpu().numpy()
+     return embed
+
+
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+                            min_pad_coverage=0.75, overlap=0.5):
+     """
+     Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+     partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
+     spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+     its spectrogram. This function assumes that the mel spectrogram parameters used are those
+     defined in params_data.py.
+
+     The returned ranges may be indexing further than the length of the waveform. It is
+     recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+     :param n_samples: the number of samples in the waveform
+     :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+     utterance
+     :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+     enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+     then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+     it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+     utterance, this parameter is ignored so that the function always returns at least 1 slice.
+     :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+     utterances are entirely disjoint.
+     :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+     respectively the waveform and the mel spectrogram with these slices to obtain the partial
+     utterances.
+     """
+     assert 0 <= overlap < 1
+     assert 0 < min_pad_coverage <= 1
+
+     samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+     n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+     frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+     # Compute the slices
+     wav_slices, mel_slices = [], []
+     steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+     for i in range(0, steps, frame_step):
+         mel_range = np.array([i, i + partial_utterance_n_frames])
+         wav_range = mel_range * samples_per_frame
+         mel_slices.append(slice(*mel_range))
+         wav_slices.append(slice(*wav_range))
+
+     # Evaluate whether extra padding is warranted or not
+     last_wav_range = wav_slices[-1]
+     coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+     if coverage < min_pad_coverage and len(mel_slices) > 1:
+         mel_slices = mel_slices[:-1]
+         wav_slices = wav_slices[:-1]
+
+     return wav_slices, mel_slices
+
+
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+     """
+     Computes an embedding for a single utterance.
+
+     # TODO: handle multiple wavs to benefit from batching on GPU
+     :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+     :param using_partials: if True, then the utterance is split in partial utterances of
+     <partial_utterance_n_frames> frames and the utterance embedding is computed from their
+     normalized average. If False, the utterance is instead computed from feeding the entire
+     spectrogram to the network.
+     :param return_partials: if True, the partial embeddings will also be returned along with the
+     wav slices that correspond to the partial embeddings.
+     :param kwargs: additional arguments to compute_partial_slices()
+     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+     <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+     returned. If <using_partials> is simultaneously set to False, both these values will be None
+     instead.
+     """
+     # Process the entire utterance if not using partials
+     if not using_partials:
+         frames = audio.wav_to_mel_spectrogram(wav)
+         embed = embed_frames_batch(frames[None, ...])[0]
+         if return_partials:
+             return embed, None, None
+         return embed
+
+     # Compute where to split the utterance into partials and pad if necessary
+     wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+     max_wave_length = wave_slices[-1].stop
+     if max_wave_length >= len(wav):
+         wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+     # Split the utterance into partials
+     frames = audio.wav_to_mel_spectrogram(wav)
+     frames_batch = np.array([frames[s] for s in mel_slices])
+     partial_embeds = embed_frames_batch(frames_batch)
+
+     # Compute the utterance embedding from the partial embeddings
+     raw_embed = np.mean(partial_embeds, axis=0)
+     embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+     if return_partials:
+         return embed, partial_embeds, wave_slices
+     return embed
+
+
+ def embed_speaker(wavs, **kwargs):
+     raise NotImplementedError()
+
+
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+     import matplotlib.pyplot as plt
+     if ax is None:
+         ax = plt.gca()
+
+     if shape is None:
+         height = int(np.sqrt(len(embed)))
+         shape = (height, -1)
+     embed = embed.reshape(shape)
+
+     cmap = cm.get_cmap()
+     mappable = ax.imshow(embed, cmap=cmap)
+     cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+     sm = cm.ScalarMappable(cmap=cmap)
+     sm.set_clim(*color_range)
+
+     ax.set_xticks([]), ax.set_yticks([])
+     ax.set_title(title)
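
A usage sketch for the inference API. Both paths below are hypothetical placeholders; the checkpoint must be a speaker-encoder checkpoint saved with the "model_state" and "step" keys this module expects:

from pathlib import Path
import numpy as np
from encoder import inference as encoder

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))   # placeholder checkpoint path
wav = encoder.preprocess_wav(Path("some_utterance.wav"))         # placeholder audio path

embed = encoder.embed_utterance(wav)
print(embed.shape, np.linalg.norm(embed))    # (256,) and ~1.0, i.e. an L2-normalised embedding

# With partials: also returns one embedding per ~1.6 s window and the matching wav slices
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)
print(partial_embeds.shape, len(wav_slices))
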
encoder/model.py ADDED
@@ -0,0 +1,135 @@
+ from encoder.params_model import *
+ from encoder.params_data import *
+ from scipy.interpolate import interp1d
+ from sklearn.metrics import roc_curve
+ from torch.nn.utils import clip_grad_norm_
+ from scipy.optimize import brentq
+ from torch import nn
+ import numpy as np
+ import torch
+
+
+ class SpeakerEncoder(nn.Module):
+     def __init__(self, device, loss_device):
+         super().__init__()
+         self.loss_device = loss_device
+
+         # Network definition
+         self.lstm = nn.LSTM(input_size=mel_n_channels,
+                             hidden_size=model_hidden_size,
+                             num_layers=model_num_layers,
+                             batch_first=True).to(device)
+         self.linear = nn.Linear(in_features=model_hidden_size,
+                                 out_features=model_embedding_size).to(device)
+         self.relu = torch.nn.ReLU().to(device)
+
+         # Cosine similarity scaling (with fixed initial parameter values)
+         self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
+         self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
+
+         # Loss
+         self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
+
+     def do_gradient_ops(self):
+         # Gradient scale
+         self.similarity_weight.grad *= 0.01
+         self.similarity_bias.grad *= 0.01
+
+         # Gradient clipping
+         clip_grad_norm_(self.parameters(), 3, norm_type=2)
+
+     def forward(self, utterances, hidden_init=None):
+         """
+         Computes the embeddings of a batch of utterance spectrograms.
+
+         :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
+         (batch_size, n_frames, n_channels)
+         :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
+         batch_size, hidden_size). Will default to a tensor of zeros if None.
+         :return: the embeddings as a tensor of shape (batch_size, embedding_size)
+         """
+         # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+         # and the final cell state.
+         out, (hidden, cell) = self.lstm(utterances, hidden_init)
+
+         # We take only the hidden state of the last layer
+         embeds_raw = self.relu(self.linear(hidden[-1]))
+
+         # L2-normalize it
+         embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
+
+         return embeds
+
+     def similarity_matrix(self, embeds):
+         """
+         Computes the similarity matrix according to section 2.1 of GE2E.
+
+         :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, embedding_size)
+         :return: the similarity matrix as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, speakers_per_batch)
+         """
+         speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+         # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
+         centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
+         centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
+
+         # Exclusive centroids (1 per utterance)
+         centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
+         centroids_excl /= (utterances_per_speaker - 1)
+         centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
+
+         # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
+         # product of these vectors (which is just an element-wise multiplication reduced by a sum).
+         # We vectorize the computation for efficiency.
+         sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
+                                  speakers_per_batch).to(self.loss_device)
+         mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
+         for j in range(speakers_per_batch):
+             mask = np.where(mask_matrix[j])[0]
+             sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
+             sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
+
+         ## Even more vectorized version (slower maybe because of transpose)
+         # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
+         #                           ).to(self.loss_device)
+         # eye = np.eye(speakers_per_batch, dtype=np.int)
+         # mask = np.where(1 - eye)
+         # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
+         # mask = np.where(eye)
+         # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
+         # sim_matrix2 = sim_matrix2.transpose(1, 2)
+
+         sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
+         return sim_matrix
+
+     def loss(self, embeds):
+         """
+         Computes the softmax loss according to section 2.1 of GE2E.
+
+         :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, embedding_size)
+         :return: the loss and the EER for this batch of embeddings.
+         """
+         speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+         # Loss
+         sim_matrix = self.similarity_matrix(embeds)
+         sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
+                                          speakers_per_batch))
+         ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
+         target = torch.from_numpy(ground_truth).long().to(self.loss_device)
+         loss = self.loss_fn(sim_matrix, target)
+
+         # EER (not backpropagated)
+         with torch.no_grad():
+             inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
+             labels = np.array([inv_argmax(i) for i in ground_truth])
+             preds = sim_matrix.detach().cpu().numpy()
+
+             # Snippet from https://yangcha.github.io/EER-ROC/
+             fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
+             eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+
+         return loss, eer
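
To make the GE2E shapes concrete, here is a small sketch feeding random, already-normalised embeddings through similarity_matrix() and loss() on the CPU. It assumes the older NumPy/PyTorch versions this repository targets, since the code above still uses the deprecated np.int alias:

import torch
from encoder.model import SpeakerEncoder

device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device=device)

# 4 speakers x 5 utterances x 256-dim embeddings, L2-normalised like forward() outputs
embeds = torch.randn(4, 5, 256)
embeds = embeds / torch.norm(embeds, dim=2, keepdim=True)

sim = model.similarity_matrix(embeds)
print(sim.shape)          # torch.Size([4, 5, 4]): each utterance against each speaker centroid

loss, eer = model.loss(embeds)
print(float(loss), eer)   # scalar softmax loss and the batch EER
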
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
+
+ ## Mel-filterbank
+ mel_window_length = 25  # In milliseconds
+ mel_window_step = 10    # In milliseconds
+ mel_n_channels = 40
+
+
+ ## Audio
+ sampling_rate = 16000
+ # Number of spectrogram frames in a partial utterance
+ partials_n_frames = 160     # 1600 ms
+ # Number of spectrogram frames at inference
+ inference_n_frames = 80     # 800 ms
+
+
+ ## Voice Activity Detection
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+ # This sets the granularity of the VAD. Should not need to be changed.
+ vad_window_length = 30  # In milliseconds
+ # Number of frames to average together when performing the moving average smoothing.
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
+ vad_moving_average_width = 8
+ # Maximum number of consecutive silent frames a segment can have.
+ vad_max_silence_length = 6
+
+
+ ## Audio volume normalization
+ audio_norm_target_dBFS = -30
+
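
A quick sanity check of how these hyperparameters translate into samples and seconds; the same arithmetic appears in audio.py and inference.py:

from encoder.params_data import *

samples_per_frame = int(sampling_rate * mel_window_step / 1000)    # 160 samples per 10 ms hop
n_fft = int(sampling_rate * mel_window_length / 1000)              # 400 samples per 25 ms window
vad_window_samples = (vad_window_length * sampling_rate) // 1000   # 480 samples per 30 ms VAD window
partial_duration = partials_n_frames * mel_window_step / 1000      # 160 frames -> 1.6 s
print(samples_per_frame, n_fft, vad_window_samples, partial_duration)
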
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
+
+ ## Model parameters
+ model_hidden_size = 256
+ model_embedding_size = 256
+ model_num_layers = 3
+
+
+ ## Training parameters
+ learning_rate_init = 1e-4
+ speakers_per_batch = 64
+ utterances_per_speaker = 10
encoder/preprocess.py ADDED
@@ -0,0 +1,184 @@
+ from datetime import datetime
+ from functools import partial
+ from multiprocessing import Pool
+ from pathlib import Path
+
+ import numpy as np
+ from tqdm import tqdm
+
+ from encoder import audio
+ from encoder.config import librispeech_datasets, anglophone_nationalites
+ from encoder.params_data import *
+
+
+ _AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
+
+ class DatasetLog:
+     """
+     Registers metadata about the dataset in a text file.
+     """
+     def __init__(self, root, name):
+         self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
+         self.sample_data = dict()
+
+         start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+         self.write_line("Creating dataset %s on %s" % (name, start_time))
+         self.write_line("-----")
+         self._log_params()
+
+     def _log_params(self):
+         from encoder import params_data
+         self.write_line("Parameter values:")
+         for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+             value = getattr(params_data, param_name)
+             self.write_line("\t%s: %s" % (param_name, value))
+         self.write_line("-----")
+
+     def write_line(self, line):
+         self.text_file.write("%s\n" % line)
+
+     def add_sample(self, **kwargs):
+         for param_name, value in kwargs.items():
+             if param_name not in self.sample_data:
+                 self.sample_data[param_name] = []
+             self.sample_data[param_name].append(value)
+
+     def finalize(self):
+         self.write_line("Statistics:")
+         for param_name, values in self.sample_data.items():
+             self.write_line("\t%s:" % param_name)
+             self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
+             self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
+         self.write_line("-----")
+         end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+         self.write_line("Finished on %s" % end_time)
+         self.text_file.close()
+
+
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
+     dataset_root = datasets_root.joinpath(dataset_name)
+     if not dataset_root.exists():
+         print("Couldn't find %s, skipping this dataset." % dataset_root)
+         return None, None
+     return dataset_root, DatasetLog(out_dir, dataset_name)
+
+
+ def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
+     # Give a name to the speaker that includes its dataset
+     speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+     # Create an output directory with that name, as well as a txt file containing a
+     # reference to each source file.
+     speaker_out_dir = out_dir.joinpath(speaker_name)
+     speaker_out_dir.mkdir(exist_ok=True)
+     sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+     # There's a possibility that the preprocessing was interrupted earlier, check if
+     # there already is a sources file.
+     if sources_fpath.exists():
+         try:
+             with sources_fpath.open("r") as sources_file:
+                 existing_fnames = {line.split(",")[0] for line in sources_file}
+         except:
+             existing_fnames = {}
+     else:
+         existing_fnames = {}
+
+     # Gather all audio files for that speaker recursively
+     sources_file = sources_fpath.open("a" if skip_existing else "w")
+     audio_durs = []
+     for extension in _AUDIO_EXTENSIONS:
+         for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+             # Check if the target output file already exists
+             out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+             out_fname = out_fname.replace(".%s" % extension, ".npy")
+             if skip_existing and out_fname in existing_fnames:
+                 continue
+
+             # Load and preprocess the waveform
+             wav = audio.preprocess_wav(in_fpath)
+             if len(wav) == 0:
+                 continue
+
+             # Create the mel spectrogram, discard those that are too short
+             frames = audio.wav_to_mel_spectrogram(wav)
+             if len(frames) < partials_n_frames:
+                 continue
+
+             out_fpath = speaker_out_dir.joinpath(out_fname)
+             np.save(out_fpath, frames)
+             sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+             audio_durs.append(len(wav) / sampling_rate)
+
+     sources_file.close()
+
+     return audio_durs
+
+
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
+     print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
+
+     # Process the utterances for each speaker
+     work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
+     with Pool(4) as pool:
+         tasks = pool.imap(work_fn, speaker_dirs)
+         for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
+             for sample_dur in sample_durs:
+                 logger.add_sample(duration=sample_dur)
+
+     logger.finalize()
+     print("Done preprocessing %s.\n" % dataset_name)
+
+
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
+     for dataset_name in librispeech_datasets["train"]["other"]:
+         # Initialize the preprocessing
+         dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+         if not dataset_root:
+             return
+
+         # Preprocess all speakers
+         speaker_dirs = list(dataset_root.glob("*"))
+         _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
+     # Initialize the preprocessing
+     dataset_name = "VoxCeleb1"
+     dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+     if not dataset_root:
+         return
+
+     # Get the contents of the meta file
+     with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
+         metadata = [line.split("\t") for line in metafile][1:]
+
+     # Select the ID and the nationality, filter out non-anglophone speakers
+     nationalities = {line[0]: line[3] for line in metadata}
+     keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
+                         nationality.lower() in anglophone_nationalites]
+     print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
+           (len(keep_speaker_ids), len(nationalities)))
+
+     # Get the speaker directories for anglophone speakers only
+     speaker_dirs = dataset_root.joinpath("wav").glob("*")
+     speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
+                     speaker_dir.name in keep_speaker_ids]
+     print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
+           (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
+
+     # Preprocess all speakers
+     _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
+     # Initialize the preprocessing
+     dataset_name = "VoxCeleb2"
+     dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+     if not dataset_root:
+         return
+
+     # Get the speaker directories
+     # Preprocess all speakers
+     speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
+     _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
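
A sketch of how these entry points are typically called. Both directories are hypothetical; datasets_root must contain the dataset folders listed in encoder/config.py (e.g. LibriSpeech/train-other-500), and since _preprocess_speaker_dirs uses multiprocessing.Pool the calls should sit under an if __name__ == "__main__" guard on platforms that spawn processes:

from pathlib import Path
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2

if __name__ == "__main__":
    datasets_root = Path("/data")             # placeholder: root of the raw datasets
    out_dir = Path("/data/SV2TTS/encoder")    # placeholder: where the .npy mel frames will go
    out_dir.mkdir(parents=True, exist_ok=True)

    preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
    preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)
    preprocess_voxceleb2(datasets_root, out_dir, skip_existing=True)
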
encoder/train.py ADDED
@@ -0,0 +1,125 @@
+ from pathlib import Path
+
+ import torch
+
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
+ from encoder.model import SpeakerEncoder
+ from encoder.params_model import *
+ from encoder.visualizations import Visualizations
+ from utils.profiler import Profiler
+
+
+ def sync(device: torch.device):
+     # For correct profiling (cuda operations are async)
+     if device.type == "cuda":
+         torch.cuda.synchronize(device)
+
+
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
+           backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
+           no_visdom: bool):
+     # Create a dataset and a dataloader
+     dataset = SpeakerVerificationDataset(clean_data_root)
+     loader = SpeakerVerificationDataLoader(
+         dataset,
+         speakers_per_batch,
+         utterances_per_speaker,
+         num_workers=4,
+     )
+
+     # Setup the device on which to run the forward pass and the loss. These can be different,
+     # because the forward pass is faster on the GPU whereas the loss is often (depending on your
+     # hyperparameters) faster on the CPU.
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     # FIXME: currently, the gradient is None if loss_device is cuda
+     loss_device = torch.device("cpu")
+
+     # Create the model and the optimizer
+     model = SpeakerEncoder(device, loss_device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
+     init_step = 1
+
+     # Configure file path for the model
+     model_dir = models_dir / run_id
+     model_dir.mkdir(exist_ok=True, parents=True)
+     state_fpath = model_dir / "encoder.pt"
+
+     # Load any existing model
+     if not force_restart:
+         if state_fpath.exists():
+             print("Found existing model \"%s\", loading it and resuming training." % run_id)
+             checkpoint = torch.load(state_fpath)
+             init_step = checkpoint["step"]
+             model.load_state_dict(checkpoint["model_state"])
+             optimizer.load_state_dict(checkpoint["optimizer_state"])
+             optimizer.param_groups[0]["lr"] = learning_rate_init
+         else:
+             print("No model \"%s\" found, starting training from scratch." % run_id)
+     else:
+         print("Starting the training from scratch.")
+     model.train()
+
+     # Initialize the visualization environment
+     vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
+     vis.log_dataset(dataset)
+     vis.log_params()
+     device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
+     vis.log_implementation({"Device": device_name})
+
+     # Training loop
+     profiler = Profiler(summarize_every=10, disabled=False)
+     for step, speaker_batch in enumerate(loader, init_step):
+         profiler.tick("Blocking, waiting for batch (threaded)")
+
+         # Forward pass
+         inputs = torch.from_numpy(speaker_batch.data).to(device)
+         sync(device)
+         profiler.tick("Data to %s" % device)
+         embeds = model(inputs)
+         sync(device)
+         profiler.tick("Forward pass")
+         embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
+         loss, eer = model.loss(embeds_loss)
+         sync(loss_device)
+         profiler.tick("Loss")
+
+         # Backward pass
+         model.zero_grad()
+         loss.backward()
+         profiler.tick("Backward pass")
+         model.do_gradient_ops()
+         optimizer.step()
+         profiler.tick("Parameter update")
+
+         # Update visualizations
+         # learning_rate = optimizer.param_groups[0]["lr"]
+         vis.update(loss.item(), eer, step)
+
+         # Draw projections and save them to the backup folder
+         if umap_every != 0 and step % umap_every == 0:
+             print("Drawing and saving projections (step %d)" % step)
+             projection_fpath = model_dir / f"umap_{step:06d}.png"
+             embeds = embeds.detach().cpu().numpy()
+             vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
+             vis.save()
+
+         # Overwrite the latest version of the model
+         if save_every != 0 and step % save_every == 0:
+             print("Saving the model (step %d)" % step)
+             torch.save({
+                 "step": step + 1,
+                 "model_state": model.state_dict(),
+                 "optimizer_state": optimizer.state_dict(),
+             }, state_fpath)
+
+         # Make a backup
+         if backup_every != 0 and step % backup_every == 0:
+             print("Making a backup (step %d)" % step)
+             backup_fpath = model_dir / f"encoder_{step:06d}.bak"
+             torch.save({
+                 "step": step + 1,
+                 "model_state": model.state_dict(),
+                 "optimizer_state": optimizer.state_dict(),
+             }, backup_fpath)
+
+         profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,179 @@
+ from datetime import datetime
+ from time import perf_counter as timer
+
+ import numpy as np
+ import umap
+ import visdom
+
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+
+
+ colormap = np.array([
+     [76, 255, 0],
+     [0, 127, 70],
+     [255, 0, 0],
+     [255, 217, 38],
+     [0, 135, 255],
+     [165, 0, 165],
+     [255, 167, 255],
+     [0, 255, 255],
+     [255, 96, 38],
+     [142, 76, 0],
+     [33, 0, 127],
+     [0, 0, 0],
+     [183, 183, 183],
+ ], dtype=np.float) / 255
+
+
+ class Visualizations:
+     def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
+         # Tracking data
+         self.last_update_timestamp = timer()
+         self.update_every = update_every
+         self.step_times = []
+         self.losses = []
+         self.eers = []
+         print("Updating the visualizations every %d steps." % update_every)
+
+         # If visdom is disabled TODO: use a better paradigm for that
+         self.disabled = disabled
+         if self.disabled:
+             return
+
+         # Set the environment name
+         now = str(datetime.now().strftime("%d-%m %Hh%M"))
+         if env_name is None:
+             self.env_name = now
+         else:
+             self.env_name = "%s (%s)" % (env_name, now)
+
+         # Connect to visdom and open the corresponding window in the browser
+         try:
+             self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
+         except ConnectionError:
+             raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
+                             "start it.")
+         # webbrowser.open("http://localhost:8097/env/" + self.env_name)
+
+         # Create the windows
+         self.loss_win = None
+         self.eer_win = None
+         # self.lr_win = None
+         self.implementation_win = None
+         self.projection_win = None
+         self.implementation_string = ""
+
+     def log_params(self):
+         if self.disabled:
+             return
+         from encoder import params_data
+         from encoder import params_model
+         param_string = "<b>Model parameters</b>:<br>"
+         for param_name in (p for p in dir(params_model) if not p.startswith("__")):
+             value = getattr(params_model, param_name)
+             param_string += "\t%s: %s<br>" % (param_name, value)
+         param_string += "<b>Data parameters</b>:<br>"
+         for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+             value = getattr(params_data, param_name)
+             param_string += "\t%s: %s<br>" % (param_name, value)
+         self.vis.text(param_string, opts={"title": "Parameters"})
+
+     def log_dataset(self, dataset: SpeakerVerificationDataset):
+         if self.disabled:
+             return
+         dataset_string = ""
+         dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
+         dataset_string += "\n" + dataset.get_logs()
+         dataset_string = dataset_string.replace("\n", "<br>")
+         self.vis.text(dataset_string, opts={"title": "Dataset"})
+
+     def log_implementation(self, params):
+         if self.disabled:
+             return
+         implementation_string = ""
+         for param, value in params.items():
+             implementation_string += "<b>%s</b>: %s\n" % (param, value)
+             implementation_string = implementation_string.replace("\n", "<br>")
+         self.implementation_string = implementation_string
+         self.implementation_win = self.vis.text(
+             implementation_string,
+             opts={"title": "Training implementation"}
+         )
+
+     def update(self, loss, eer, step):
+         # Update the tracking data
+         now = timer()
+         self.step_times.append(1000 * (now - self.last_update_timestamp))
+         self.last_update_timestamp = now
+         self.losses.append(loss)
+         self.eers.append(eer)
+         print(".", end="")
+
+         # Update the plots every <update_every> steps
+         if step % self.update_every != 0:
+             return
+         time_string = "Step time: mean: %5dms std: %5dms" % \
+                       (int(np.mean(self.step_times)), int(np.std(self.step_times)))
+         print("\nStep %6d Loss: %.4f EER: %.4f %s" %
+               (step, np.mean(self.losses), np.mean(self.eers), time_string))
+         if not self.disabled:
+             self.loss_win = self.vis.line(
+                 [np.mean(self.losses)],
+                 [step],
+                 win=self.loss_win,
+                 update="append" if self.loss_win else None,
+                 opts=dict(
+                     legend=["Avg. loss"],
+                     xlabel="Step",
+                     ylabel="Loss",
+                     title="Loss",
+                 )
+             )
+             self.eer_win = self.vis.line(
+                 [np.mean(self.eers)],
+                 [step],
+                 win=self.eer_win,
+                 update="append" if self.eer_win else None,
+                 opts=dict(
+                     legend=["Avg. EER"],
+                     xlabel="Step",
+                     ylabel="EER",
+                     title="Equal error rate"
+                 )
+             )
+             if self.implementation_win is not None:
+                 self.vis.text(
+                     self.implementation_string + ("<b>%s</b>" % time_string),
+                     win=self.implementation_win,
+                     opts={"title": "Training implementation"},
+                 )
+
+         # Reset the tracking
+         self.losses.clear()
+         self.eers.clear()
+         self.step_times.clear()
+
+     def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
+         import matplotlib.pyplot as plt
+
+         max_speakers = min(max_speakers, len(colormap))
+         embeds = embeds[:max_speakers * utterances_per_speaker]
+
+         n_speakers = len(embeds) // utterances_per_speaker
+         ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
+         colors = [colormap[i] for i in ground_truth]
+
+         reducer = umap.UMAP()
+         projected = reducer.fit_transform(embeds)
+         plt.scatter(projected[:, 0], projected[:, 1], c=colors)
+         plt.gca().set_aspect("equal", "datalim")
+         plt.title("UMAP projection (step %d)" % step)
+         if not self.disabled:
+             self.projection_win = self.vis.matplot(plt, win=self.projection_win)
+         if out_fpath is not None:
+             plt.savefig(out_fpath)
+         plt.clf()
+
+     def save(self):
+         if not self.disabled:
+             self.vis.save([self.env_name])