import os
from pathlib import Path
from typing import Optional, List, Dict
import zipfile
import tempfile
from dataclasses import dataclass
from itertools import groupby

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale


def trim_or_pad_to_target_length(
    data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
    assert len(data_1d_or_2d.shape) in {1, 2}
    delta = data_1d_or_2d.shape[0] - target_length
    if delta >= 0:  # trim if being longer
        data_1d_or_2d = data_1d_or_2d[: target_length]
    else:  # pad if being shorter
        if len(data_1d_or_2d.shape) == 1:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros(-delta)], axis=0
            )
        else:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
                axis=0
            )
    return data_1d_or_2d
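# Quick sanity check of the helper above:
#   trim_or_pad_to_target_length(np.ones(5), 3).shape         # (3,)
#   trim_or_pad_to_target_length(np.ones((2, 4)), 4).shape    # (4, 4), zero-padded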
def extract_logmel_spectrogram(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, win_length: int = 1024,
    hop_length: int = 256, n_fft: int = 1024,
    win_fn: callable = torch.hann_window, n_mels: int = 80,
    f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
    overwrite: bool = False, target_length: Optional[int] = None
):
    if output_path is not None and output_path.is_file() and not overwrite:
        return

    spectrogram_transform = TTSSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        window_fn=win_fn
    )
    mel_scale_transform = TTSMelScale(
        n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
        n_stft=n_fft // 2 + 1
    )
    spectrogram = spectrogram_transform(waveform)
    mel_spec = mel_scale_transform(spectrogram)
    logmel_spec = torch.clamp(mel_spec, min=eps).log()
    assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
    logmel_spec = logmel_spec.squeeze().t()  # D x T -> T x D
    if target_length is not None:
        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)

    if output_path is not None:
        np.save(output_path.as_posix(), logmel_spec)
    else:
        return logmel_spec
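# Minimal usage sketch (assumes torchaudio and a 22,050 Hz mono "sample.wav",
# neither of which is part of this module):
#   waveform, sample_rate = torchaudio.load("sample.wav")
#   logmel = extract_logmel_spectrogram(waveform, sample_rate)  # (n_frames, n_mels)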
def extract_pitch(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, hop_length: int = 256,
    log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    try:
        import pyworld
    except ImportError:
        raise ImportError("Please install PyWORLD: pip install pyworld")

    _waveform = waveform.squeeze(0).double().numpy()
    pitch, t = pyworld.dio(
        _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
    )  # raw F0 estimation; frame_period is in milliseconds
    pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)  # F0 refinement

    if phoneme_durations is not None:
        pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
        try:
            from scipy.interpolate import interp1d
        except ImportError:
            raise ImportError("Please install SciPy: pip install scipy")
        # linearly interpolate F0 over unvoiced frames (where DIO outputs 0)
        nonzero_ids = np.where(pitch != 0)[0]
        interp_fn = interp1d(
            nonzero_ids,
            pitch[nonzero_ids],
            fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
            bounds_error=False,
        )
        pitch = interp_fn(np.arange(0, len(pitch)))
        # average frame-level F0 within each phoneme
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        pitch = np.array(
            [
                np.mean(pitch[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(pitch) == len(phoneme_durations)

    if log_scale:
        pitch = np.log(pitch + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), pitch)
    else:
        return pitch
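# Usage sketch: frame-level log-F0, or phoneme-level averages when durations are
# given (the [3, 5, 4] below is illustrative; real durations would come from the
# frame_durations of an alignment, see get_mfa_alignment further down):
#   f0 = extract_pitch(waveform, sample_rate)
#   f0_per_phone = extract_pitch(waveform, sample_rate, phoneme_durations=[3, 5, 4])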
def extract_energy(
    waveform: torch.Tensor, output_path: Optional[Path] = None,
    hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
    phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    assert len(waveform.shape) == 2 and waveform.shape[0] == 1
    waveform = waveform.view(1, 1, waveform.shape[1])
    # reflect-pad by n_fft // 2 on both sides, as a centered STFT would
    waveform = F.pad(
        waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
        mode="reflect"
    )
    waveform = waveform.squeeze(1)

    # DFT matrix with real and imaginary parts stacked as separate rows
    fourier_basis = np.fft.fft(np.eye(n_fft))
    cutoff = int((n_fft / 2 + 1))
    fourier_basis = np.vstack(
        [np.real(fourier_basis[:cutoff, :]),
         np.imag(fourier_basis[:cutoff, :])]
    )

    # STFT as a strided 1D convolution against the DFT basis
    forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
    forward_transform = F.conv1d(
        waveform, forward_basis, stride=hop_length, padding=0
    )

    real_part = forward_transform[:, :cutoff, :]
    imag_part = forward_transform[:, cutoff:, :]
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    # energy = L2 norm of the magnitude spectrum in each frame
    energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()

    if phoneme_durations is not None:
        energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
        # average frame-level energy within each phoneme
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        energy = np.array(
            [
                np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(energy) == len(phoneme_durations)

    if log_scale:
        energy = np.log(energy + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), energy)
    else:
        return energy
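# For intuition: the conv1d above computes a rectangular-window STFT, so the
# per-frame energies roughly match this torch.stft sketch (not a drop-in
# replacement; edge padding details may differ):
#   spec = torch.stft(waveform.squeeze(0), n_fft=1024, hop_length=256,
#                     window=torch.ones(1024), center=True,
#                     pad_mode="reflect", return_complex=True)
#   energy = spec.abs().norm(dim=0)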
def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
    # accumulate per-dimension sums and squared sums over all feature files
    mean_x, mean_x2, n_frames = None, None, 0
    feature_paths = feature_root.glob("*.npy")
    for p in tqdm(feature_paths):
        with open(p, 'rb') as f:
            frames = np.load(f).squeeze()

        n_frames += frames.shape[0]

        cur_mean_x = frames.sum(axis=0)
        if mean_x is None:
            mean_x = cur_mean_x
        else:
            mean_x += cur_mean_x

        cur_mean_x2 = (frames ** 2).sum(axis=0)
        if mean_x2 is None:
            mean_x2 = cur_mean_x2
        else:
            mean_x2 += cur_mean_x2

    mean_x /= n_frames
    mean_x2 /= n_frames
    var_x = mean_x2 - mean_x ** 2  # Var[X] = E[X^2] - E[X]^2
    std_x = np.sqrt(np.maximum(var_x, 1e-10))

    if output_path is not None:
        with open(output_path, 'wb') as f:
            np.savez(f, mean=mean_x, std=std_x)
    else:
        return {"mean": mean_x, "std": std_x}
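# Applying the saved statistics (a sketch; "gcmvn.npz" is a hypothetical path):
#   stats = np.load("gcmvn.npz")
#   normalized = (features - stats["mean"]) / stats["std"]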
def ipa_phonemize(text, lang="en-us", use_g2p=False):
    if use_g2p:
        assert lang == "en-us", "g2pE phonemizer only works for en-us"
        try:
            from g2p_en import G2p
            g2p = G2p()
            # use "|" as the word separator, matching the espeak branch below
            return " ".join("|" if p == " " else p for p in g2p(text))
        except ImportError:
            raise ImportError(
                "Please install g2p_en: pip install g2p_en"
            )
    else:
        try:
            from phonemizer import phonemize
            from phonemizer.separator import Separator
            return phonemize(
                text, backend='espeak', language=lang,
                separator=Separator(word="| ", phone=" ")
            )
        except ImportError:
            raise ImportError(
                "Please install phonemizer: pip install phonemizer"
            )
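# Illustrative output (exact symbols depend on the installed backend/version):
#   ipa_phonemize("hello world", use_g2p=True)
#   # -> "HH AH0 L OW1 | W ER1 L D" (ARPAbet, "|" marking the word boundary)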
@dataclass
class ForceAlignmentInfo(object):
    tokens: List[str]
    frame_durations: List[int]
    start_sec: Optional[float]
    end_sec: Optional[float]
def get_mfa_alignment_by_sample_id(
    textgrid_zip_path: str, sample_id: str, sample_rate: int,
    hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
    try:
        import tgt
    except ImportError:
        raise ImportError("Please install TextGridTools: pip install tgt")

    # extract the per-sample TextGrid from the zip to a temporary location
    filename = f"{sample_id}.TextGrid"
    out_root = Path(tempfile.gettempdir())
    tgt_path = out_root / filename
    with zipfile.ZipFile(textgrid_zip_path) as f_zip:
        f_zip.extract(filename, path=out_root)
    textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
    os.remove(tgt_path)

    phones, frame_durations = [], []
    start_sec, end_sec, end_idx = 0, 0, 0
    for t in textgrid.get_tier_by_name("phones")._objects:
        s, e, p = t.start_time, t.end_time, t.text

        # skip leading silences; track the end of the last non-silence phone
        if len(phones) == 0:
            if p in silence_phones:
                continue
            else:
                start_sec = s
        phones.append(p)
        if p not in silence_phones:
            end_sec = e
            end_idx = len(phones)
        # convert the interval from seconds to feature frames
        r = sample_rate / hop_length
        frame_durations.append(int(np.round(e * r) - np.round(s * r)))

    # drop trailing silences
    phones = phones[:end_idx]
    frame_durations = frame_durations[:end_idx]

    return ForceAlignmentInfo(
        tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
        end_sec=end_sec
    )
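# Frame bookkeeping: with sample_rate=22050 and hop_length=256,
# r = 22050 / 256 ≈ 86.13 frames/s, so a phone spanning 0.10 s maps to ~9 frames.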
def get_mfa_alignment(
    textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
    hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
    return {
        i: get_mfa_alignment_by_sample_id(
            textgrid_zip_path, i, sample_rate, hop_length
        ) for i in tqdm(sample_ids)
    }
def get_unit_alignment(
    id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
    id_to_units = {
        e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
    }
    id_to_units = {i: id_to_units[i].split() for i in sample_ids}
    # collapse runs of repeated units; the run lengths become frame durations
    id_to_units_collapsed = {
        i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
    }
    id_to_durations = {
        i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
    }

    return {
        i: ForceAlignmentInfo(
            tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
            start_sec=None, end_sec=None
        )
        for i in sample_ids
    }
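# E.g. a unit sequence "5 5 5 9 9 5" collapses to tokens ["5", "9", "5"] with
# frame_durations [3, 2, 1]: each run of repeated units becomes one token whose
# duration is the run length.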