import os
import statistics
import sys

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio


class FastSpeechDataset(Dataset):
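    """
    Dataset for FastSpeech 2 style training.

    Builds on top of the AlignerDataset cache and augments every sample with
    phoneme durations (derived from a trained aligner), pitch (DIO), energy,
    and an utterance-level prosodic condition vector. Everything is computed
    once and stored in fast_train_cache.pt inside the given cache_dir.

    A minimal usage sketch (the paths are illustrative, not prescribed):

        train_set = FastSpeechDataset(path_to_transcript_dict=transcript_dict,
                                      acoustic_checkpoint_path="Models/Aligner/aligner.pt",
                                      cache_dir="Corpora/MyCorpus",
                                      lang="en",
                                      device=torch.device("cuda"))
    """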

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache:
            if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=rebuild_cache,
                               device=device)
            datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            # we use the aligner dataset as the basis and augment it with the additional information FastSpeech 2 needs
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=True)
                datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()

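            # the trained aligner yields an alignment path per sample, from which durations
            # are derived, plus a CTC loss that we later use as a data quality score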
            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            acoustic_model = acoustic_model.to(device)
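            # all feature extractors operate on 16kHz audio, matching the normalized waves from the aligner cache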
            dio = Dio(reduction_factor=reduction_factor, fs=16000)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=16000)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=16000, device=device)

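            # for every sample: align, derive phoneme durations, then extract pitch, energy
            # and an utterance-level prosodic condition from the corresponding wave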
            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                norm_wave_length = torch.LongTensor([len(norm_wave)])

                # skip samples that are (still) too short, e.g. after silence removal
                if len(norm_wave) / 16000 < min_len_in_seconds and ctc_selection:
                    continue

                text = dataset[index][0]
                melspec = dataset[index][2]
                melspec_length = dataset[index][3]

                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)

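                # collapse the alignment path into one integer duration value per phoneme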
                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()

                last_vec = None
                for phoneme_index, vec in enumerate(text):
                    if last_vec is not None:
                        if last_vec.numpy().tolist() == vec.numpy().tolist():
                            # we found a case of repeating phonemes, which the Viterbi decoding cannot handle properly.
                            # we repair their durations by giving the first one 3/5 of their sum and the second one the rest (2/5).
                            dur_1 = cached_duration[phoneme_index - 1]
                            dur_2 = cached_duration[phoneme_index]
                            total_dur = dur_1 + dur_2
                            new_dur_1 = int((total_dur / 5) * 3)
                            new_dur_2 = total_dur - new_dur_1
                            cached_duration[phoneme_index - 1] = new_dur_1
                            cached_duration[phoneme_index] = new_dur_2
                    last_vec = vec

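                # energy and pitch are extracted from the wave; passing the durations along
                # lets the calculators map the curves onto the phoneme level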
                cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=cached_duration.unsqueeze(0),
                                            durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()

                cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=cached_duration.unsqueeze(0),
                                   durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()

                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # audio without any voiced segments whatsoever cannot yield a prosodic condition, so we skip it
                    continue

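                # datapoint layout: text, text length, melspec, melspec length,
                # durations, energy, pitch, prosodic condition (mirrored in __getitem__)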
                self.datapoints.append([dataset[index][0],
                                        dataset[index][1],
                                        dataset[index][2],
                                        dataset[index][3],
                                        cached_duration.cpu(),
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection and len(self.ctc_losses) > 1:
                # now we can filter out bad datapoints based on the CTC scores we collected
                # (a standard deviation needs at least two samples)
                mean_ctc = statistics.mean(self.ctc_losses)
                std_dev = statistics.stdev(self.ctc_losses)
                threshold = mean_ctc + std_dev
                # iterate in reverse so that popping does not shift the indices we still have to visit
                for index in reversed(range(len(self.ctc_losses))):
                    if self.ctc_losses[index] > threshold:
                        self.datapoints.pop(index)
                        print(f"Removing datapoint {index}, because its CTC loss is more than one standard deviation above the mean. \n"
                              f"ctc: {round(self.ctc_losses[index], 4)} vs. mean: {round(mean_ctc, 4)}")

            # save to cache
            if len(self.datapoints) > 0:
                torch.save(self.datapoints, os.path.join(cache_dir, "fast_train_cache.pt"))
            else:
                print("No datapoints were prepared! Exiting...")
                sys.exit()
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(os.path.join(cache_dir, "fast_train_cache.pt"), map_location='cpu')

        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    def __getitem__(self, index):
        return (self.datapoints[index][0],
                self.datapoints[index][1],
                self.datapoints[index][2],
                self.datapoints[index][3],
                self.datapoints[index][4],
                self.datapoints[index][5],
                self.datapoints[index][6],
                self.datapoints[index][7],
                self.language_id)

    def __len__(self):
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
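        # remove in descending index order so that earlier pops don't shift the later indices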
        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The Viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically
        during cache creation, but if you have a cache from
        before March 2022, use this method to postprocess those cases.
        """
        for datapoint_index in tqdm(list(range(len(self.datapoints)))):
            last_vec = None
            for phoneme_index, vec in enumerate(self.datapoints[datapoint_index][0]):
                if last_vec is not None:
                    if last_vec.numpy().tolist() == vec.numpy().tolist():
                        # we found a case of repeating phonemes, which the Viterbi decoding cannot handle properly.
                        # we repair their durations by giving the first one 3/5 of their sum and the second one the rest (2/5).
                        dur_1 = self.datapoints[datapoint_index][4][phoneme_index - 1]
                        dur_2 = self.datapoints[datapoint_index][4][phoneme_index]
                        total_dur = dur_1 + dur_2
                        new_dur_1 = int((total_dur / 5) * 3)
                        new_dur_2 = total_dur - new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index - 1] = new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index] = new_dur_2
                        print("fix applied")
                last_vec = vec
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")