import os
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

# Symbol inventory and lookup table shared by the VITS phoneme add-ons below.
CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

# The padding symbol gets id 0; punctuation, letters, IPA symbols and special
# markers follow in that order.
CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}
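# Illustrative check (hedged; the exact ids follow only from the symbol order
# defined above): the pad symbol "_" receives id 0 and the first punctuation
# mark ";" receives id 1, e.g.
#   assert CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]["_"] == 0
#   assert CACHE["get_vits_phoneme_ids"]["_symbol_to_id"][";"] == 1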


def get_vits_phoneme_ids(config, dl_output, metadata):
    """Convert a VITS phoneme string into a fixed-length id tensor.

    A blank (pad) token is interleaved between consecutive phoneme ids and the
    result is zero-padded up to PAD_LENGTH.
    """
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"]
    sequence = []

    for symbol in clean_text:
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    # Interleave a blank token (id 0) around every phoneme: [a, b] -> [0, a, 0, b, 0].
    inserted_zero_sequence = [0] * (len(sequence) * 2)
    inserted_zero_sequence[1::2] = sequence
    inserted_zero_sequence = inserted_zero_sequence + [0]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}
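# Illustrative sketch (assumes only the symbol table defined above; config and
# dl_output are unused by this add-on): for metadata = {"phonemes": "ab"} the two
# letter ids are interleaved with blanks to give five entries and then zero-padded, so
#   get_vits_phoneme_ids({}, {}, {"phonemes": "ab"})["phoneme_idx"].shape
# is torch.Size([310]).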


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    """Convert a VITS phoneme string into ids without blank-token interleaving.

    An end marker "⚠" is appended, unknown symbols are replaced by the pad
    symbol "_", and the sequence is truncated and zero-padded to PAD_LENGTH.
    """
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id.keys():
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}


def calculate_relative_bandwidth(config, dl_output, metadata):
    """Estimate the frequency band containing 5%–95% of the STFT energy.

    The band edges are returned as indices rescaled to a 0–1000 range so that
    they are independent of the STFT frequency resolution.
    """
    assert "stft" in dl_output.keys()

    # Cumulative energy distribution over frequency bins.
    freq_dimensions = dl_output["stft"].size(-1)
    freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    # Rescale bin indices to a resolution-independent 0–1000 range.
    lower_idx = int((lower_idx / freq_dimensions) * 1000)
    higher_idx = int((higher_idx / freq_dimensions) * 1000)

    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}
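# Illustrative sketch (hypothetical input; assumes a non-negative "stft" tensor
# laid out as [time, freq], as the code above implies):
#   stft = torch.rand(100, 512)
#   calculate_relative_bandwidth({}, {"stft": stft}, {})
# returns {"freq_energy_percentile": tensor([low, high])} with both edges lying
# in the 0–1000 range.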


def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    """Build a binary band-pass mask over the latent grid from mel-spectrogram energy.

    The 5%–95% energy band is located on the linearised mel spectrogram and
    mapped onto the latent frequency axis; the mask is 1 inside the band and 0
    outside.
    """
    assert "log_mel_spec" in dl_output.keys()
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # Cumulative energy distribution over mel bins.
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    # Map mel-bin indices onto the latent frequency axis.
    lower_idx = int(latent_f_size * float(lower_idx / freq_dimensions))
    higher_idx = int(latent_f_size * float(higher_idx / freq_dimensions))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }


def waveform_rs_48k(config, dl_output, metadata):
    waveform = dl_output["waveform"]
    sampling_rate = dl_output["sampling_rate"]

    if sampling_rate != 48000:
        waveform_48k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=48000
        )
    else:
        waveform_48k = waveform

    return {"waveform_48k": waveform_48k}


def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    assert (
        "phoneme" not in metadata.keys()
    ), "The speech metadata you are using seems to belong to FastSpeech. Please check dataset_root.json"

    if "phonemes" in metadata.keys():
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
        new_item["text"] = ""
    else:
        fake_metadata = {"phonemes": ""}
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)

    return new_item


def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    if "phoneme" in metadata.keys():
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
        new_item["text"] = ""
    else:
        fake_metadata = {"phoneme": []}
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)
    return new_item


def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Map FastSpeech2-style g2p_en phoneme labels to a fixed-length id tensor."""
    PAD_LENGTH = 135

    phonemes_lookup_dict = {
        "K": 0,
        "IH2": 1,
        "NG": 2,
        "OW2": 3,
        "AH2": 4,
        "F": 5,
        "AE0": 6,
        "IY0": 7,
        "SH": 8,
        "G": 9,
        "W": 10,
        "UW1": 11,
        "AO2": 12,
        "AW2": 13,
        "UW0": 14,
        "EY2": 15,
        "UW2": 16,
        "AE2": 17,
        "IH0": 18,
        "P": 19,
        "D": 20,
        "ER1": 21,
        "AA1": 22,
        "EH0": 23,
        "UH1": 24,
        "N": 25,
        "V": 26,
        "AY1": 27,
        "EY1": 28,
        "UH2": 29,
        "EH1": 30,
        "L": 31,
        "AA2": 32,
        "R": 33,
        "OY1": 34,
        "Y": 35,
        "ER2": 36,
        "S": 37,
        "AE1": 38,
        "AH1": 39,
        "JH": 40,
        "ER0": 41,
        "EH2": 42,
        "IY2": 43,
        "OY2": 44,
        "AW1": 45,
        "IH1": 46,
        "IY1": 47,
        "OW0": 48,
        "AO0": 49,
        "AY0": 50,
        "EY0": 51,
        "AY2": 52,
        "UH0": 53,
        "M": 54,
        "TH": 55,
        "T": 56,
        "OY0": 57,
        "AW0": 58,
        "DH": 59,
        "Z": 60,
        "spn": 61,
        "AH0": 62,
        "sp": 63,
        "AO1": 64,
        "OW1": 65,
        "ZH": 66,
        "B": 67,
        "AA0": 68,
        "CH": 69,
        "HH": 70,
    }
    # The padding id is one past the largest phoneme id.
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_fs2_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme annotation is provided in your dataset"

    # Drop any label that is not in the lookup table.
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is much longer than PAD_LENGTH and will be heavily truncated: %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}
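# Illustrative sketch (assumes only the lookup table above): with
#   metadata = {"phoneme": ["HH", "AH0", "XX"]}
# the unknown label "XX" is dropped, so the result is a LongTensor of length 135
# starting with [70, 62] and filled up with the pad id 71.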


def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Map stress-stripped g2p_en phoneme labels to a fixed-length id tensor."""
    PAD_LENGTH = 250

    phonemes_lookup_dict = {
        " ": 0,
        "AA": 1,
        "AE": 2,
        "AH": 3,
        "AO": 4,
        "AW": 5,
        "AY": 6,
        "B": 7,
        "CH": 8,
        "D": 9,
        "DH": 10,
        "EH": 11,
        "ER": 12,
        "EY": 13,
        "F": 14,
        "G": 15,
        "HH": 16,
        "IH": 17,
        "IY": 18,
        "JH": 19,
        "K": 20,
        "L": 21,
        "M": 22,
        "N": 23,
        "NG": 24,
        "OW": 25,
        "OY": 26,
        "P": 27,
        "R": 28,
        "S": 29,
        "SH": 30,
        "T": 31,
        "TH": 32,
        "UH": 33,
        "UW": 34,
        "V": 35,
        "W": 36,
        "Y": 37,
        "Z": 38,
        "ZH": 39,
    }
    # The padding id is one past the largest phoneme id.
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme annotation is provided in your dataset"
    # Drop any label that is not in the lookup table.
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is much longer than PAD_LENGTH and will be heavily truncated: %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_kaldi_fbank_feature(config, dl_output, metadata):
    """Compute a 128-bin Kaldi-style fbank at 16 kHz, aligned to the HiFi-GAN mel length."""
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Pad with zeros or truncate along time so the fbank matches the HiFi-GAN mel length.
    TARGET_LEN = log_mel_spec_hifigan.size(0)

    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}
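# Illustrative sketch (hypothetical inputs; shapes follow the code above): given a
# mono waveform tensor of shape [1, num_samples], its sampling rate, and a
# log_mel_spec of shape [T, n_mels], the returned "ta_kaldi_fbank" has shape
# [T, 128], zero-padded or truncated in time to match T and normalized with the
# fixed mean/std above.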


def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 32000:
        waveform_32k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=32000
        )
    else:
        waveform_32k = waveform

    waveform_32k = waveform_32k - waveform_32k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_32k,
        htk_compat=True,
        sample_frequency=32000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec_hifigan.size(0)

    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}


def extract_drum_beat(config, dl_output, metadata):
    """Render beat/downbeat annotations as a conditioning map over the latent grid."""

    def visualization(conditional_signal, mel_spectrogram, filename):
        # Debugging helper: dump the current segment and plot the conditioning
        # signal next to the mel spectrogram. Not called in the normal pipeline.
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))

        plt.subplot(211)
        plt.imshow(np.array(conditional_signal).T, aspect="auto")
        plt.title("Conditional Signal")

        plt.subplot(212)
        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
        plt.title("Mel Spectrogram")

        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]

    original_segment_length_before_resample = int(sampling_rate * duration)

    random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])

    # Keep only the beats/downbeats that fall inside the sampled segment,
    # expressed relative to the segment start.
    beat = [
        x - random_start_sample
        for x in metadata["beat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]
    downbeat = [
        x - random_start_sample
        for x in metadata["downbeat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]

    latent_shape = (
        config["model"]["params"]["latent_t_size"],
        config["model"]["params"]["latent_f_size"],
    )
    conditional_signal = torch.zeros(latent_shape)

    # Beats subtract 0.5 and downbeats add 1.0 at the corresponding latent time
    # step, so a position annotated as both ends up at +0.5.
    for each in beat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] -= 0.5

    for each in downbeat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] += 1.0

    return {"cond_beat_downbeat": conditional_signal}
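# Illustrative sketch (hypothetical config values): with latent_t_size=256 and
# latent_f_size=16, the returned "cond_beat_downbeat" tensor has shape (256, 16);
# rows aligned with annotated beats hold -0.5, rows aligned with downbeats hold
# +1.0 (or +0.5 where a downbeat coincides with a beat), and all other rows are
# zero.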