# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example Run Command: python make_supdata.py --ssl_model_ckpt_path <PATH TO CKPT> --manifest_paths <PATH TO MANIFEST>
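# A fuller example invocation (all flags below are defined in main()):
# python make_supdata.py --ssl_model_ckpt_path <PATH TO CKPT> --manifest_paths <PATH TO MANIFEST> \
#     --sup_data_dir <PATH TO SUP DATA DIR> --compute_pitch_contours 1 --num_pitch_per_speaker 50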
import argparse
import json
import os
import time
from multiprocessing import Pool
from pathlib import Path
import hydra.utils
import librosa
import numpy as np
import torch
from omegaconf import open_dict
from tqdm import tqdm
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.tts.models import ssl_tts
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir
from nemo.core.classes import Dataset
from nemo.utils import logging
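# Dataset over NeMo-style JSON manifests (one record per line with "audio_filepath",
# "duration" and "speaker"). Records outside [min_duration, max_duration] are skipped
# and each waveform is zero-padded to a multiple of `pad_multiple` samples.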
class AudioDataset(Dataset):
def __init__(
self,
manifest_paths,
min_duration=0.5,
max_duration=16.0,
pad_multiple=1024,
sample_rate=22050,
sup_data_dir=None,
):
self.data = []
for manifest_path in manifest_paths:
with open(manifest_path, "r") as f:
for line in f:
record = json.loads(line)
if record['duration'] < min_duration or record['duration'] > max_duration:
continue
                    self.data.append(record)
self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data])
if sup_data_dir is not None:
self.sup_data_dir = sup_data_dir
else:
self.sup_data_dir = os.path.join(self.base_data_dir, "sup_data")
if not os.path.exists(self.sup_data_dir):
os.makedirs(self.sup_data_dir)
self.pad_multiple = pad_multiple
self.sample_rate = sample_rate
def __len__(self):
return len(self.data)
def _get_wav_from_filepath(self, audio_filepath):
features = AudioSegment.segment_from_file(
audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False,
)
audio_samples = features.samples
audio, audio_length = torch.tensor(audio_samples), torch.tensor(audio_samples.shape[0]).long()
# pad audio to a multiple of self.pad_multiple
if audio.shape[0] % self.pad_multiple != 0:
audio = torch.cat(
[audio, torch.zeros(self.pad_multiple - audio.shape[0] % self.pad_multiple, dtype=torch.float)]
)
audio_length = torch.tensor(audio.shape[0]).long()
return audio, audio_length
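    # Collate: right-pad every audio tensor in the batch to the longest utterance and
    # stack all tensor fields; string fields (paths, ids) are kept as Python lists.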
def pad_collate_fn(self, batch):
final_batch = {}
for row in batch:
for key in row:
if key not in final_batch:
final_batch[key] = []
final_batch[key].append(row[key])
max_audio_len = max([_audio_len.item() for _audio_len in final_batch["audio_len"]])
audios_padded = []
for audio in final_batch["audio"]:
audio_padded = torch.nn.functional.pad(audio, (0, max_audio_len - audio.size(0)), value=0)
audios_padded.append(audio_padded)
final_batch["audio"] = audios_padded
for key in final_batch:
if key not in ["rel_audio_path_as_text_id", "wav_path"]:
final_batch[key] = torch.stack(final_batch[key])
return final_batch
def __getitem__(self, index):
sample = self.data[index]
rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("")
rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")
speaker = torch.tensor(sample["speaker"]).long()
audio, audio_length = self._get_wav_from_filepath(sample["audio_filepath"])
return {
"audio": audio,
"audio_len": audio_length,
"rel_audio_path_as_text_id": rel_audio_path_as_text_id,
"wav_path": sample["audio_filepath"],
"speaker": speaker,
}
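# Split a waveform into fixed-length, overlapping segments (zero-padding short ones) so
# that speaker embeddings can be computed over uniform-length chunks.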
def segment_wav(wav, segment_length, segment_hop_size, min_segment_length):
if len(wav) < segment_length:
pad = torch.zeros(segment_length - len(wav))
segment = torch.cat([wav, pad])
return [segment]
else:
si = 0
segments = []
while si < len(wav) - min_segment_length:
segment = wav[si : si + segment_length]
if len(segment) < segment_length:
pad = torch.zeros(segment_length - len(segment))
segment = torch.cat([segment, pad])
segments.append(segment)
si += segment_hop_size
return segments
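# Segment every utterance in a batch and record, per utterance, the (start, end) indices
# of its segments within the stacked segment tensor.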
def segment_batch(batch, segment_length=44100, segment_hop_size=22050, min_segment_length=22050):
all_segments = []
segment_indices = []
si = 0
for bidx in range(len(batch['audio'])):
audio = batch['audio'][bidx]
audio_length = batch['audio_len'][bidx]
audio_actual = audio[:audio_length]
audio_segments = segment_wav(audio_actual, segment_length, segment_hop_size, min_segment_length)
all_segments += audio_segments
segment_indices.append((si, si + len(audio_segments) - 1))
si += len(audio_segments)
return torch.stack(all_segments), segment_indices
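# Log-mel spectrogram: magnitude STFT of the waveform, projected through the mel
# filterbank `fb` and clamped to a tiny positive value before taking the log.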
def get_mel_spectrogram(fb, wav, stft_params):
EPSILON = 1e-9
window_fn = torch.hann_window
spec = torch.stft(
input=wav,
n_fft=stft_params['n_fft'], # 1024
hop_length=stft_params['hop_length'], # 256
win_length=stft_params['win_length'], # 1024
        window=window_fn(stft_params['win_length'], periodic=False).to(torch.float).to(wav.device) if window_fn else None,
return_complex=True,
center=True,
)
if spec.dtype in [torch.cfloat, torch.cdouble]:
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON)
mel = torch.matmul(fb.to(spec.dtype), spec)
log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
return log_mel
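# Load a waveform at the target sample rate and pad it to a multiple of `pad_multiple`
# samples (mirroring AudioDataset's padding) before pitch extraction.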
def load_wav(wav_path, sample_rate=22050, pad_multiple=1024):
wav = AudioSegment.segment_from_file(wav_path, target_sr=sample_rate, n_segments=-1, trim=False,).samples
if wav.shape[0] % pad_multiple != 0:
wav = np.concatenate([wav, np.zeros(pad_multiple - wav.shape[0] % pad_multiple)])
wav = wav[:-1]
return wav
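# Estimate a frame-level F0 contour with librosa's pYIN (unvoiced frames filled with 0)
# and save it as pitch_contour_<wav_id>.pt in the sup_data directory.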
def save_pitch_contour(record):
wav_path = record['wav_path']
wav_text_id = record['wav_id']
sup_data_dir = record['sup_data_dir']
stft_params = record['stft_params']
wav = load_wav(wav_path, stft_params['sample_rate'], stft_params['pad_multiple'])
pitch_contour_fn = f"pitch_contour_{wav_text_id}.pt"
pitch_contour_fp = os.path.join(sup_data_dir, pitch_contour_fn)
f0, _, _ = librosa.pyin(
wav,
fmin=librosa.note_to_hz('C2'),
fmax=stft_params['yin_fmax'],
frame_length=stft_params['win_length'],
hop_length=stft_params['hop_length'],
sr=stft_params['sample_rate'],
center=True,
fill_na=0.0,
)
pitch_contour = torch.tensor(f0, dtype=torch.float32)
torch.save(pitch_contour, pitch_contour_fp)
logging.info("saved {}".format(pitch_contour_fp))
return pitch_contour
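# Aggregate per-speaker pitch statistics (mean/std over voiced frames of at most the
# first 50 contours per speaker) into speaker_pitch_stats.json; speakers with missing or
# implausible values fall back to defaults of 212 Hz mean and 70 Hz std.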
def compute_pitch_stats(records):
def _is_valid_pitch(pitch_mean, pitch_std):
        c1 = 0 < pitch_mean < 1000
        c2 = 0 < pitch_std < 1000
return c1 and c2
speaker_wise_pitch_contours = {}
for item in records:
wav_id = item['wav_id']
speaker = item['speaker']
sup_data_dir = item['sup_data_dir']
pitch_contour_fn = f"pitch_contour_{wav_id}.pt"
pitch_contour_fp = os.path.join(sup_data_dir, pitch_contour_fn)
if speaker not in speaker_wise_pitch_contours:
speaker_wise_pitch_contours[speaker] = []
speaker_wise_pitch_contours[speaker].append(pitch_contour_fp)
speaker_pitch_stats = {}
for speaker in speaker_wise_pitch_contours:
non_zero_pc = []
for pitch_contour_fp in speaker_wise_pitch_contours[speaker][:50]:
pitch_contour = torch.load(pitch_contour_fp)
pitch_contour_nonzero = pitch_contour[pitch_contour != 0]
if len(pitch_contour_nonzero) > 0:
non_zero_pc.append(pitch_contour_nonzero)
if len(non_zero_pc) > 0:
non_zero_pc = torch.cat(non_zero_pc)
pitch_mean = non_zero_pc.mean().item()
pitch_std = non_zero_pc.std().item()
valid = True
if not _is_valid_pitch(pitch_mean, pitch_std):
logging.warning("invalid pitch: {}".format(speaker))
pitch_mean = 212.0
pitch_std = 70.0
valid = "False"
else:
logging.warning("could not find pitch contour for speaker {}".format(speaker))
valid = "False"
pitch_mean = 212.0
pitch_std = 70.0
speaker_pitch_stats[speaker] = {"pitch_mean": pitch_mean, "pitch_std": pitch_std, "valid": valid}
with open(os.path.join(sup_data_dir, "speaker_pitch_stats.json"), "w") as f:
json.dump(speaker_pitch_stats, f)
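# Main pipeline: for every manifest entry, save to the sup_data directory a content
# embedding, an L2-normalized speaker embedding, per-token durations, a log-mel
# spectrogram and, optionally, a pitch contour plus per-speaker pitch statistics.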
def main():
    parser = argparse.ArgumentParser(description='Compute supplementary data for SSL TTS training')
parser.add_argument(
'--ssl_model_ckpt_path', type=str, required=True,
)
parser.add_argument('--manifest_paths', type=str, required=True)
parser.add_argument('--sup_data_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--ssl_content_emb_type', type=str, default="embedding_and_probs")
parser.add_argument('--use_unique_tokens', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--pool_workers', type=int, default=30)
parser.add_argument('--compute_pitch_contours', type=int, default=1)
    parser.add_argument('--num_pitch_per_speaker', type=int, default=None)  # limit pitch contours per speaker to save time
parser.add_argument('--sample_rate', type=int, default=22050)
parser.add_argument('--pad_multiple', type=int, default=1024)
parser.add_argument('--ssl_downsampling_factor', type=int, default=4)
parser.add_argument('--stft_n_fft', type=int, default=1024)
parser.add_argument('--stft_hop_length', type=int, default=256)
parser.add_argument('--stft_win_length', type=int, default=1024)
parser.add_argument('--stft_n_mel', type=int, default=80)
parser.add_argument('--stft_fmin', type=int, default=0)
parser.add_argument('--stft_fmax', type=int, default=8000)
parser.add_argument('--yin_fmax', type=int, default=500)
parser.add_argument('--segment_length', type=int, default=44100)
parser.add_argument('--segment_hop_size', type=int, default=22050)
parser.add_argument('--min_segment_length', type=int, default=22050)
args = parser.parse_args()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
manifest_paths = args.manifest_paths.split(",")
ssl_model_ckpt_path = args.ssl_model_ckpt_path
dataset = AudioDataset(
manifest_paths, pad_multiple=args.pad_multiple, sample_rate=args.sample_rate, sup_data_dir=args.sup_data_dir
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=dataset.pad_collate_fn,
num_workers=args.num_workers,
)
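    # Load the SSL disentangler from the checkpoint and re-instantiate its preprocessor
    # with exact_pad enabled, then switch to eval mode on the selected device.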
ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(ssl_model_ckpt_path, strict=False)
with open_dict(ssl_model.cfg):
ssl_model.cfg.preprocessor.exact_pad = True
ssl_model.preprocessor = hydra.utils.instantiate(ssl_model.cfg.preprocessor)
ssl_model.preprocessor_disentangler = ssl_model.preprocessor
ssl_model.eval()
ssl_model.to(device)
sample_rate = args.sample_rate
stft_params = {
"n_fft": args.stft_n_fft,
"hop_length": args.stft_hop_length,
"win_length": args.stft_win_length,
"n_mel": args.stft_n_mel,
"sample_rate": sample_rate,
"pad_multiple": args.pad_multiple,
"fmin": args.stft_fmin,
"fmax": args.stft_fmax,
"yin_fmax": args.yin_fmax,
}
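    # Mel filterbank matrix of shape (1, n_mel, n_fft // 2 + 1), built once and kept on
    # the device so it can be reused for every batch.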
fb = (
torch.tensor(
librosa.filters.mel(
sr=sample_rate,
n_fft=stft_params['n_fft'],
n_mels=stft_params['n_mel'],
fmin=stft_params['fmin'],
fmax=stft_params['fmax'],
),
dtype=torch.float,
)
.unsqueeze(0)
.to(device)
)
st = time.time()
bidx = 0
wav_and_id_list = []
for batch in tqdm(dataloader):
bidx += 1
with torch.no_grad():
(
_,
_,
batch_content_embedding,
batch_content_log_probs,
batch_encoded_len,
) = ssl_model.forward_for_export(
input_signal=batch['audio'].to(device),
input_signal_length=batch['audio_len'].to(device),
normalize_content=True,
)
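            # The first pass above (full utterances) yields content embeddings and token
            # log-probs; next, log-mel spectrograms are computed and a second pass over
            # fixed-length segments yields speaker embeddings, averaged per utterance below.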
batch_mel_specs = get_mel_spectrogram(fb, batch['audio'][:, :-1].to(device), stft_params)
audio_segmented, segment_indices = segment_batch(
batch, args.segment_length, args.segment_hop_size, args.min_segment_length
)
audio_seg_len = torch.tensor([len(segment) for segment in audio_segmented]).to(device).long()
_, batch_speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
input_signal=audio_segmented.to(device), input_signal_length=audio_seg_len, normalize_content=True,
)
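            # Per-utterance post-processing: trim to the encoded length, build the chosen
            # content representation, L2-normalize the averaged segment speaker embedding,
            # and save each tensor as a .pt file keyed by the utterance's relative-path id.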
for idx in range(batch['audio'].shape[0]):
_speaker = batch['speaker'][idx].item()
wav_path = batch['wav_path'][idx]
wav_id = batch['rel_audio_path_as_text_id'][idx]
wav_and_id_list.append((wav_path, wav_id, _speaker))
content_embedding = batch_content_embedding[idx].detach()
content_log_probs = batch_content_log_probs[:, idx, :].detach() # (content lob prob is (t, b, c))
encoded_len = batch_encoded_len[idx].detach()
content_embedding = content_embedding[: encoded_len.item()]
content_embedding = content_embedding.t()
content_log_probs = content_log_probs[: encoded_len.item()]
content_log_probs = content_log_probs.t()
content_probs = torch.exp(content_log_probs)
duration = torch.ones(content_embedding.shape[1]) * args.ssl_downsampling_factor
bsi_start = segment_indices[idx][0]
bsi_end = segment_indices[idx][1]
speaker_embedding = torch.mean(batch_speaker_embeddings[bsi_start : bsi_end + 1], dim=0)
l2_norm = torch.norm(speaker_embedding, p=2)
speaker_embedding = speaker_embedding / l2_norm
if args.ssl_content_emb_type == "probs":
# content embedding is only character probabilities
final_content_embedding = content_probs
elif args.ssl_content_emb_type == "embedding":
# content embedding is only output of content head of SSL backbone
final_content_embedding = content_embedding
elif args.ssl_content_emb_type == "log_probs":
# content embedding is only log of character probabilities
final_content_embedding = content_log_probs
elif args.ssl_content_emb_type == "embedding_and_probs":
# content embedding is the concatenation of character probabilities and output of content head of SSL backbone
final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)
if args.use_unique_tokens == 1:
# group content embeddings with same predicted token (by averaging) and add the durations of the grouped embeddings
# Eg. By default each content embedding corresponds to 4 frames of spectrogram (ssl_downsampling_factor)
# If we group 3 content embeddings, the duration of the grouped embedding will be 12 frames.
# This is useful for adapting the duration during inference based on the speaker.
token_predictions = torch.argmax(content_probs, dim=0)
content_buffer = [final_content_embedding[:, 0]]
unique_content_embeddings = []
unique_tokens = []
durations = []
for _t in range(1, final_content_embedding.shape[1]):
if token_predictions[_t] == token_predictions[_t - 1]:
content_buffer.append(final_content_embedding[:, _t])
else:
durations.append(len(content_buffer) * args.ssl_downsampling_factor)
unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
content_buffer = [final_content_embedding[:, _t]]
unique_tokens.append(token_predictions[_t].item())
if len(content_buffer) > 0:
durations.append(len(content_buffer) * args.ssl_downsampling_factor)
unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
unique_tokens.append(token_predictions[_t].item())
unique_content_embedding = torch.stack(unique_content_embeddings)
final_content_embedding = unique_content_embedding.t()
duration = torch.tensor(durations).float()
mel_len = int(batch['audio_len'][idx].item() / stft_params['hop_length'])
item_mel = batch_mel_specs[idx][:, :mel_len]
wav_text_id = batch["rel_audio_path_as_text_id"][idx]
content_emb_fn = f"{args.ssl_content_emb_type}_content_embedding_{wav_text_id}.pt"
speaker_emb_fn = f"speaker_embedding_{wav_text_id}.pt"
duration_fn = f"duration_embedding_{wav_text_id}.pt" # embedding just for namesake
content_emb_fp = os.path.join(dataset.sup_data_dir, content_emb_fn)
speaker_emb_fp = os.path.join(dataset.sup_data_dir, speaker_emb_fn)
duration_fp = os.path.join(dataset.sup_data_dir, duration_fn)
mel_spec_fn = f"mel_spec_{wav_text_id}.pt"
mel_spec_fp = os.path.join(dataset.sup_data_dir, mel_spec_fn)
torch.save(item_mel.cpu(), mel_spec_fp)
torch.save(final_content_embedding.cpu(), content_emb_fp)
torch.save(speaker_embedding.cpu(), speaker_emb_fp)
torch.save(duration.cpu(), duration_fp)
et = time.time()
logging.info(
"Processed Batch {} of {} | Time per batch: {:.4f} s".format(
                bidx, len(dataloader), (et - st) / bidx
)
)
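    # Optionally compute pitch contours in a multiprocessing pool (capped per speaker by
    # --num_pitch_per_speaker when set) and then aggregate per-speaker pitch statistics.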
if args.compute_pitch_contours == 1:
speaker_wise_records = {}
for row in wav_and_id_list:
wav_path, wav_id, speaker = row
if speaker not in speaker_wise_records:
speaker_wise_records[speaker] = []
speaker_wise_records[speaker].append(
{
"wav_path": wav_path,
"wav_id": wav_id,
"sup_data_dir": dataset.sup_data_dir,
"stft_params": stft_params,
"speaker": speaker,
}
)
filtered_records = []
for speaker in speaker_wise_records:
if args.num_pitch_per_speaker is not None:
filtered_records += speaker_wise_records[speaker][: args.num_pitch_per_speaker]
else:
filtered_records += speaker_wise_records[speaker]
with Pool(args.pool_workers) as p:
p.map(save_pitch_contour, filtered_records)
compute_pitch_stats(filtered_records)
if __name__ == '__main__':
main()