crystal-technologies's picture
Upload 1287 files
2d8da09
raw
history blame
5.15 kB
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script computes pitch statistics (mean and standard deviation) for a given TTS training manifest.
The statistics are computed over the whole manifest and separately for every speaker found in it.
Run this script after extract_sup_data.py, because it reads the precomputed supplementary pitch features.
$ python <nemo_root_path>/scripts/dataset_processing/tts/compute_speaker_stats.py \
--manifest_path=<data_root_path>/fastpitch_manifest.json \
--sup_data_path=<data_root_path>/sup_data \
--pitch_stats_path=<data_root_path>/pitch_stats.json
"""
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple
import torch
from tqdm import tqdm
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir
from nemo.collections.tts.torch.tts_data_types import Pitch
from nemo.utils import logging
def get_args():
    """
    Parse the command line options of this script.

    Returns:
        Namespace with the attributes 'manifest_path', 'sup_data_path' and 'pitch_stats_path',
        each converted to a Path.
    """
    arg_parser = argparse.ArgumentParser(
        description="Compute speaker level pitch statistics.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Only the manifest is mandatory; the other two options fall back to relative default paths.
    arg_parser.add_argument("--manifest_path", type=Path, required=True, help="Path to training manifest.")
    arg_parser.add_argument(
        "--sup_data_path",
        type=Path,
        default=Path("sup_data"),
        help="Path to base directory with supplementary data.",
    )
    arg_parser.add_argument(
        "--pitch_stats_path",
        type=Path,
        default=Path("pitch_stats.json"),
        help="Path to output JSON file with speaker pitch statistics.",
    )
    return arg_parser.parse_args()
def _compute_stats(values: List[torch.Tensor]) -> Tuple[float, float]:
values_tensor = torch.cat(values, dim=0)
mean = values_tensor.mean().item()
std = values_tensor.std(dim=0).item()
return mean, std
def _get_sup_data_filepath(manifest_entry: dict, audio_dir: Path, sup_data_dir: Path) -> Path:
"""
Get the absolute path of a supplementary data type for the input manifest entry.
Example: audio_filepath "<audio_dir>/speaker1/audio1.wav" becomes "<sup_data_dir>/speaker1_audio1.pt"
Args:
manifest_entry: Manifest entry dictionary.
audio_dir: base directory where audio is stored.
sup_data_dir: base directory where supplementary data is stored.
Returns:
Path to the supplementary data file.
"""
audio_path = Path(manifest_entry["audio_filepath"])
rel_audio_path = audio_path.relative_to(audio_dir)
rel_sup_data_path = rel_audio_path.with_suffix(".pt")
sup_data_filename = str(rel_sup_data_path).replace(os.sep, "_")
sup_data_filepath = sup_data_dir / sup_data_filename
return sup_data_filepath
def main():
    """
    Compute global and per-speaker pitch statistics and write them to a JSON file.

    For every manifest entry the precomputed pitch tensor is loaded from "<sup_data_path>/<Pitch.name>",
    entries without a pitch file are skipped with a warning. The output JSON maps "default" to the
    statistics over all loaded entries and every speaker ID to the statistics of that speaker.

    Raises:
        ValueError: If the pitch directory does not exist, or if no pitch file was found for any
            manifest entry.
    """
    args = get_args()
    manifest_path = args.manifest_path
    sup_data_path = args.sup_data_path
    pitch_stats_path = args.pitch_stats_path

    pitch_data_path = Path(os.path.join(sup_data_path, Pitch.name))
    if not os.path.exists(pitch_data_path):
        raise ValueError(
            f"Pitch directory {pitch_data_path} does not exist. Make sure 'sup_data_path' is correct "
            f"and that you have computed the pitch using extract_sup_data.py"
        )

    entries = read_manifest(manifest_path)
    audio_paths = [entry["audio_filepath"] for entry in entries]
    # Supplementary file names are derived from the audio paths relative to this common base directory.
    base_dir = get_base_dir(audio_paths)

    global_pitch_values = []
    speaker_pitch_values = defaultdict(list)
    for entry in tqdm(entries):
        pitch_path = _get_sup_data_filepath(manifest_entry=entry, audio_dir=base_dir, sup_data_dir=pitch_data_path)
        if not os.path.exists(pitch_path):
            logging.warning(f"Unable to find pitch file for {entry}")
            continue

        # NOTE: torch.load unpickles the file, so only point this script at supplementary data you trust.
        pitch = torch.load(pitch_path)
        # Filter out non-speech frames
        pitch = pitch[pitch != 0]
        global_pitch_values.append(pitch)
        if "speaker" in entry:
            speaker_id = entry["speaker"]
            speaker_pitch_values[speaker_id].append(pitch)

    # Without this check an empty list reaches torch.cat and fails with an error that does not
    # tell the user what is actually wrong.
    if not global_pitch_values:
        raise ValueError(
            f"No pitch files were found in {pitch_data_path} for the entries of {manifest_path}. "
            f"Make sure that you have computed the pitch for this manifest using extract_sup_data.py"
        )

    global_pitch_mean, global_pitch_std = _compute_stats(global_pitch_values)
    pitch_stats = {"default": {"pitch_mean": global_pitch_mean, "pitch_std": global_pitch_std}}
    # NOTE: a speaker whose ID is the string "default" overwrites the global entry above.
    for speaker_id, pitch_values in speaker_pitch_values.items():
        pitch_mean, pitch_std = _compute_stats(pitch_values)
        pitch_stats[speaker_id] = {"pitch_mean": pitch_mean, "pitch_std": pitch_std}

    with open(pitch_stats_path, 'w', encoding="utf-8") as stats_f:
        json.dump(pitch_stats, stats_f, indent=4)


if __name__ == "__main__":
    main()