|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This script is to compute speaker-level statistics, such as pitch mean & standard deviation, for a given |
|
TTS training manifest. |
|
|
|
This script should be run after extract_sup_data.py as it uses the precomputed supplemental features. |
|
|
|
$ python <nemo_root_path>/scripts/dataset_processing/tts/compute_speaker_stats.py \ |
|
--manifest_path=<data_root_path>/fastpitch_manifest.json \ |
|
--sup_data_path=<data_root_path>/sup_data \ |
|
--pitch_stats_path=<data_root_path>/pitch_stats.json |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
from collections import defaultdict |
|
from pathlib import Path |
|
from typing import List, Tuple |
|
|
|
import torch |
|
from tqdm import tqdm |
|
|
|
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest |
|
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir |
|
from nemo.collections.tts.torch.tts_data_types import Pitch |
|
from nemo.utils import logging |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser( |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute speaker level pitch statistics.", |
|
) |
|
parser.add_argument( |
|
"--manifest_path", required=True, type=Path, help="Path to training manifest.", |
|
) |
|
parser.add_argument( |
|
"--sup_data_path", default=Path("sup_data"), type=Path, help="Path to base directory with supplementary data.", |
|
) |
|
parser.add_argument( |
|
"--pitch_stats_path", |
|
default=Path("pitch_stats.json"), |
|
type=Path, |
|
help="Path to output JSON file with speaker pitch statistics.", |
|
) |
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def _compute_stats(values: List[torch.Tensor]) -> Tuple[float, float]: |
|
values_tensor = torch.cat(values, dim=0) |
|
mean = values_tensor.mean().item() |
|
std = values_tensor.std(dim=0).item() |
|
return mean, std |
|
|
|
|
|
def _get_sup_data_filepath(manifest_entry: dict, audio_dir: Path, sup_data_dir: Path) -> Path: |
|
""" |
|
Get the absolute path of a supplementary data type for the input manifest entry. |
|
|
|
Example: audio_filepath "<audio_dir>/speaker1/audio1.wav" becomes "<sup_data_dir>/speaker1_audio1.pt" |
|
|
|
Args: |
|
manifest_entry: Manifest entry dictionary. |
|
audio_dir: base directory where audio is stored. |
|
sup_data_dir: base directory where supplementary data is stored. |
|
|
|
Returns: |
|
Path to the supplementary data file. |
|
""" |
|
audio_path = Path(manifest_entry["audio_filepath"]) |
|
rel_audio_path = audio_path.relative_to(audio_dir) |
|
rel_sup_data_path = rel_audio_path.with_suffix(".pt") |
|
sup_data_filename = str(rel_sup_data_path).replace(os.sep, "_") |
|
sup_data_filepath = sup_data_dir / sup_data_filename |
|
return sup_data_filepath |
|
|
|
|
|
def main(): |
|
args = get_args() |
|
manifest_path = args.manifest_path |
|
sup_data_path = args.sup_data_path |
|
pitch_stats_path = args.pitch_stats_path |
|
|
|
pitch_data_path = Path(os.path.join(sup_data_path, Pitch.name)) |
|
if not os.path.exists(pitch_data_path): |
|
raise ValueError( |
|
f"Pitch directory {pitch_data_path} does not exist. Make sure 'sup_data_path' is correct " |
|
f"and that you have computed the pitch using extract_sup_data.py" |
|
) |
|
|
|
entries = read_manifest(manifest_path) |
|
|
|
audio_paths = [entry["audio_filepath"] for entry in entries] |
|
base_dir = get_base_dir(audio_paths) |
|
|
|
global_pitch_values = [] |
|
speaker_pitch_values = defaultdict(list) |
|
for entry in tqdm(entries): |
|
pitch_path = _get_sup_data_filepath(manifest_entry=entry, audio_dir=base_dir, sup_data_dir=pitch_data_path) |
|
if not os.path.exists(pitch_path): |
|
logging.warning(f"Unable to find pitch file for {entry}") |
|
continue |
|
|
|
pitch = torch.load(pitch_path) |
|
|
|
pitch = pitch[pitch != 0] |
|
global_pitch_values.append(pitch) |
|
if "speaker" in entry: |
|
speaker_id = entry["speaker"] |
|
speaker_pitch_values[speaker_id].append(pitch) |
|
|
|
global_pitch_mean, global_pitch_std = _compute_stats(global_pitch_values) |
|
pitch_stats = {"default": {"pitch_mean": global_pitch_mean, "pitch_std": global_pitch_std}} |
|
for speaker_id, pitch_values in speaker_pitch_values.items(): |
|
pitch_mean, pitch_std = _compute_stats(pitch_values) |
|
pitch_stats[speaker_id] = {"pitch_mean": pitch_mean, "pitch_std": pitch_std} |
|
|
|
with open(pitch_stats_path, 'w', encoding="utf-8") as stats_f: |
|
json.dump(pitch_stats, stats_f, indent=4) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|