import argparse |
import json |
import os |
import time |
from multiprocessing import Pool |
from pathlib import Path |
import hydra.utils |
import librosa |
import numpy as np |
import torch |
from omegaconf import open_dict |
from tqdm import tqdm |
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment |
from nemo.collections.tts.models import ssl_tts |
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir |
from nemo.core.classes import Dataset |
from nemo.utils import logging |
class AudioDataset(Dataset):
    """Dataset of manifest-listed audio clips used for supplementary-data extraction.

    Records are read from one or more JSON-lines manifests. Each record must
    provide ``audio_filepath`` and ``duration`` (used for filtering here) and
    ``speaker`` (converted to a long tensor in ``__getitem__``). Records whose
    ``duration`` is outside ``[min_duration, max_duration]`` are skipped.

    Audio is loaded at ``sample_rate`` and right-padded with zeros to a multiple
    of ``pad_multiple`` samples.

    Args:
        manifest_paths: iterable of manifest file paths.
        min_duration: smallest ``duration`` value to keep (same units as the
            manifest field, presumably seconds).
        max_duration: largest ``duration`` value to keep.
        pad_multiple: clip lengths are padded up to a multiple of this.
        sample_rate: target sample rate for ``AudioSegment.segment_from_file``.
        sup_data_dir: output directory for supplementary files; defaults to
            ``<base_data_dir>/sup_data``. Created if it does not exist.

    Raises:
        ValueError: if no manifest record survives the duration filter.
    """

    def __init__(
        self,
        manifest_paths,
        min_duration=0.5,
        max_duration=16.0,
        pad_multiple=1024,
        sample_rate=22050,
        sup_data_dir=None,
    ):
        self.data = []
        for manifest_path in manifest_paths:
            # Manifests are JSON, which is UTF-8; don't depend on the platform default.
            with open(manifest_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads.
                        continue
                    record = json.loads(line)
                    if record['duration'] < min_duration or record['duration'] > max_duration:
                        continue
                    # Reuse the parsed record rather than decoding the line a second time.
                    self.data.append(record)

        if not self.data:
            raise ValueError(
                "No manifest records with duration in [{}, {}] found in {}".format(
                    min_duration, max_duration, manifest_paths
                )
            )

        self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data])
        if sup_data_dir is not None:
            self.sup_data_dir = sup_data_dir
        else:
            self.sup_data_dir = os.path.join(self.base_data_dir, "sup_data")
        # exist_ok avoids the check-then-create race when several jobs share the directory.
        os.makedirs(self.sup_data_dir, exist_ok=True)

        self.pad_multiple = pad_multiple
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.data)

    def _get_wav_from_filepath(self, audio_filepath):
        """Load one clip and right-pad it with zeros to a multiple of ``self.pad_multiple``.

        Returns:
            ``(audio, audio_length)``: a 1-D audio tensor and its padded length
            as a 0-d long tensor.
        """
        features = AudioSegment.segment_from_file(
            audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False,
        )
        audio = torch.tensor(features.samples)
        remainder = audio.shape[0] % self.pad_multiple
        if remainder != 0:
            audio = torch.cat([audio, torch.zeros(self.pad_multiple - remainder, dtype=torch.float)])
        audio_length = torch.tensor(audio.shape[0]).long()
        return audio, audio_length

    def pad_collate_fn(self, batch):
        """Collate samples, zero-padding ``audio`` to the longest clip in the batch.

        ``audio_len`` keeps each clip's own length. String fields
        (``rel_audio_path_as_text_id``, ``wav_path``) stay as Python lists;
        every other field is stacked into a tensor.
        """
        final_batch = {}
        for row in batch:
            for key, value in row.items():
                final_batch.setdefault(key, []).append(value)

        max_audio_len = max(audio_len.item() for audio_len in final_batch["audio_len"])
        final_batch["audio"] = [
            torch.nn.functional.pad(audio, (0, max_audio_len - audio.size(0)), value=0)
            for audio in final_batch["audio"]
        ]

        for key in final_batch:
            if key not in ("rel_audio_path_as_text_id", "wav_path"):
                final_batch[key] = torch.stack(final_batch[key])
        return final_batch

    def __getitem__(self, index):
        """Return the padded audio, its length, ids/paths and speaker for record ``index``."""
        sample = self.data[index]
        # Path relative to the common base dir, without extension, flattened into a file-name-safe id.
        rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("")
        rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")
        speaker = torch.tensor(sample["speaker"]).long()
        audio, audio_length = self._get_wav_from_filepath(sample["audio_filepath"])
        return {
            "audio": audio,
            "audio_len": audio_length,
            "rel_audio_path_as_text_id": rel_audio_path_as_text_id,
            "wav_path": sample["audio_filepath"],
            "speaker": speaker,
        }
def segment_wav(wav, segment_length, segment_hop_size, min_segment_length):
    """Split a 1-D waveform into fixed-length, possibly overlapping segments.

    A waveform shorter than ``segment_length`` is zero-padded into a single
    segment. Otherwise a window of ``segment_length`` samples starts every
    ``segment_hop_size`` samples, while the start index is below
    ``len(wav) - min_segment_length``. A trailing window that runs past the end
    is zero-padded.

    At least one segment is always returned. If ``min_segment_length`` is so
    large that no start index qualifies, the first ``segment_length`` samples
    are used; an empty list would make downstream averaging produce NaNs.

    Args:
        wav: 1-D tensor of samples.
        segment_length: samples per segment.
        segment_hop_size: stride between segment starts; must be positive.
        min_segment_length: a segment may only start if at least this many
            samples remain after its start.

    Returns:
        list of 1-D tensors, each of length ``segment_length``, with
        ``wav``'s dtype and device.
    """

    def _pad_to_segment_length(chunk):
        # Zero-pad on the right, matching the input's dtype/device.
        shortfall = segment_length - len(chunk)
        if shortfall <= 0:
            return chunk
        pad = torch.zeros(shortfall, dtype=chunk.dtype, device=chunk.device)
        return torch.cat([chunk, pad])

    if len(wav) < segment_length:
        return [_pad_to_segment_length(wav)]

    segments = [
        _pad_to_segment_length(wav[start : start + segment_length])
        for start in range(0, len(wav) - min_segment_length, segment_hop_size)
    ]
    if not segments:
        # len(wav) >= segment_length here, so this window needs no padding.
        segments = [wav[:segment_length]]
    return segments
def segment_batch(batch, segment_length=44100, segment_hop_size=22050, min_segment_length=22050):
    """Segment every utterance of a collated batch and flatten the pieces.

    Each utterance is cut to its true length (``batch['audio_len']``) before
    being passed to ``segment_wav``.

    Returns:
        ``(segments, spans)``. ``segments`` is all segments stacked into one
        tensor. ``spans[i]`` is the inclusive ``(first, last)`` row range in
        ``segments`` that belongs to utterance ``i``.
    """
    flat_segments = []
    spans = []
    for position, padded_audio in enumerate(batch['audio']):
        true_length = batch['audio_len'][position]
        pieces = segment_wav(padded_audio[:true_length], segment_length, segment_hop_size, min_segment_length)
        first_row = len(flat_segments)
        flat_segments.extend(pieces)
        spans.append((first_row, first_row + len(pieces) - 1))
    return torch.stack(flat_segments), spans
def get_mel_spectrogram(fb, wav, stft_params):
    """Compute a natural-log mel spectrogram.

    Args:
        fb: mel filterbank tensor, left-multiplied with the magnitude
            spectrogram (``main`` passes shape (1, n_mel, n_fft // 2 + 1)).
        wav: waveform tensor accepted by ``torch.stft`` (``main`` passes
            (batch, samples)). It may live on any device.
        stft_params: dict with ``n_fft``, ``hop_length`` and ``win_length``.

    Returns:
        ``log(clamp(fb @ mag, min=tiny))``, where
        ``mag = sqrt(re**2 + im**2 + 1e-9)`` of a centered STFT with a
        symmetric (``periodic=False``) Hann window.
    """
    EPSILON = 1e-9
    # Put the window on the input's device/dtype. It was previously hard-wired
    # to CUDA, which broke the CPU fallback selected in main().
    window = torch.hann_window(stft_params['win_length'], periodic=False).to(dtype=wav.dtype, device=wav.device)
    spec = torch.stft(
        input=wav,
        n_fft=stft_params['n_fft'],
        hop_length=stft_params['hop_length'],
        win_length=stft_params['win_length'],
        window=window,
        return_complex=True,
        center=True,
    )
    if spec.dtype in [torch.cfloat, torch.cdouble]:
        spec = torch.view_as_real(spec)
    # The epsilon keeps sqrt (and the later log) away from exact zeros.
    spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON)
    mel = torch.matmul(fb.to(device=spec.device, dtype=spec.dtype), spec)
    log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
    return log_mel
def load_wav(wav_path, sample_rate=22050, pad_multiple=1024):
    """Load a waveform as a NumPy array, zero-padded and then shortened by one sample.

    The clip is read with ``AudioSegment.segment_from_file`` at ``sample_rate``.
    If its length is not a multiple of ``pad_multiple``, it is right-padded with
    zeros up to the next multiple. The final sample is always dropped.

    NOTE(review): dropping one sample mirrors ``batch['audio'][:, :-1]`` used for
    mel extraction in ``main``. Presumably this makes centered frame counts equal
    ``padded_length / hop_length``; confirm before changing.

    ``np.zeros`` defaults to float64, so padding can change the returned
    array's dtype relative to an unpadded clip.
    """
    samples = AudioSegment.segment_from_file(wav_path, target_sr=sample_rate, n_segments=-1, trim=False,).samples
    leftover = samples.shape[0] % pad_multiple
    if leftover:
        samples = np.concatenate([samples, np.zeros(pad_multiple - leftover)])
    return samples[:-1]
def save_pitch_contour(record):
    """Estimate one utterance's pYIN f0 contour and save it as a float32 tensor.

    Args:
        record: dict with ``wav_path``, ``wav_id``, ``sup_data_dir`` and
            ``stft_params``. ``stft_params`` must contain ``sample_rate``,
            ``pad_multiple``, ``win_length``, ``hop_length`` and ``yin_fmax``.

    The contour is computed with ``librosa.pyin``:
    - fmin is C2 and fmax is ``yin_fmax``;
    - frame and hop sizes come from ``stft_params``;
    - frames are centered;
    - unvoiced frames are filled with 0.0.

    The result is written to ``<sup_data_dir>/pitch_contour_<wav_id>.pt``.

    Returns:
        the saved 1-D float32 tensor.
    """
    params = record['stft_params']
    samples = load_wav(record['wav_path'], params['sample_rate'], params['pad_multiple'])
    out_path = os.path.join(record['sup_data_dir'], f"pitch_contour_{record['wav_id']}.pt")

    f0, _voiced_flag, _voiced_prob = librosa.pyin(
        samples,
        fmin=librosa.note_to_hz('C2'),
        fmax=params['yin_fmax'],
        frame_length=params['win_length'],
        hop_length=params['hop_length'],
        sr=params['sample_rate'],
        center=True,
        fill_na=0.0,
    )
    contour = torch.tensor(f0, dtype=torch.float32)
    torch.save(contour, out_path)
    logging.info("saved {}".format(out_path))
    return contour
def compute_pitch_stats(records, max_contours_per_speaker=50, default_pitch_mean=212.0, default_pitch_std=70.0):
    """Compute per-speaker pitch mean/std from saved pitch contours.

    For each speaker, up to ``max_contours_per_speaker`` files named
    ``pitch_contour_<wav_id>.pt`` are loaded from each record's
    ``sup_data_dir``. Their non-zero frames are pooled, and the mean and sample
    standard deviation are computed. If no non-zero frames exist, or either
    statistic falls outside (0, 1000), the speaker gets the default mean/std
    and is marked invalid.

    The result is written as ``speaker_pitch_stats.json`` into the
    ``sup_data_dir`` of the LAST record, mapping
    ``speaker -> {"pitch_mean", "pitch_std", "valid"}``.

    NOTE(review): ``valid`` is the boolean ``True`` for good speakers but the
    string ``"False"`` otherwise, so a plain truthiness check treats both as
    valid. The values are kept as-is for compatibility with existing readers of
    this JSON; confirm the consumer before normalising.

    Args:
        records: iterable of dicts with ``wav_id``, ``speaker`` and ``sup_data_dir``.
        max_contours_per_speaker: cap on contours loaded per speaker (``None`` = all).
        default_pitch_mean: fallback mean for speakers without valid statistics.
        default_pitch_std: fallback std for speakers without valid statistics.

    Returns:
        None. If ``records`` is empty, a warning is logged and nothing is
        written; previously this raised UnboundLocalError.
    """

    def _is_valid_pitch(pitch_mean, pitch_std):
        return 0 < pitch_mean < 1000 and 0 < pitch_std < 1000

    contour_paths_by_speaker = {}
    sup_data_dir = None
    for item in records:
        sup_data_dir = item['sup_data_dir']
        pitch_contour_fp = os.path.join(sup_data_dir, f"pitch_contour_{item['wav_id']}.pt")
        contour_paths_by_speaker.setdefault(item['speaker'], []).append(pitch_contour_fp)

    if not contour_paths_by_speaker:
        logging.warning("compute_pitch_stats: no records given; speaker_pitch_stats.json not written")
        return

    speaker_pitch_stats = {}
    for speaker, contour_paths in contour_paths_by_speaker.items():
        voiced_frames = []
        for pitch_contour_fp in contour_paths[:max_contours_per_speaker]:
            pitch_contour = torch.load(pitch_contour_fp)
            # Zeros are the pYIN fill value for unvoiced frames; exclude them from the stats.
            nonzero = pitch_contour[pitch_contour != 0]
            if len(nonzero) > 0:
                voiced_frames.append(nonzero)

        if voiced_frames:
            pooled = torch.cat(voiced_frames)
            pitch_mean = pooled.mean().item()
            pitch_std = pooled.std().item()  # NaN for a single frame -> fails validation below
            valid = True
            if not _is_valid_pitch(pitch_mean, pitch_std):
                logging.warning("invalid pitch: {}".format(speaker))
                pitch_mean = default_pitch_mean
                pitch_std = default_pitch_std
                valid = "False"
        else:
            logging.warning("could not find pitch contour for speaker {}".format(speaker))
            valid = "False"
            pitch_mean = default_pitch_mean
            pitch_std = default_pitch_std

        speaker_pitch_stats[speaker] = {"pitch_mean": pitch_mean, "pitch_std": pitch_std, "valid": valid}

    with open(os.path.join(sup_data_dir, "speaker_pitch_stats.json"), "w") as f:
        json.dump(speaker_pitch_stats, f)
def _merge_repeated_tokens(content_embedding, token_predictions, downsampling_factor):
    """Average consecutive content frames that share the same predicted token.

    Args:
        content_embedding: 2-D tensor laid out as (feature_dim, num_frames),
            matching the ``.t()``-transposed features built in ``main``.
        token_predictions: 1-D tensor with one predicted token per frame.
        downsampling_factor: duration units represented by one frame.

    Returns:
        ``(merged, durations)``. ``merged`` has one column per run of identical
        consecutive tokens; each column is the mean of that run's columns.
        ``durations`` is a float tensor of ``run_length * downsampling_factor``.
    """
    _, run_lengths = torch.unique_consecutive(token_predictions, return_counts=True)
    durations = run_lengths.float() * downsampling_factor
    if run_lengths.numel() == 0:
        # No frames at all: nothing to merge.
        return content_embedding, durations
    runs = torch.split(content_embedding, run_lengths.tolist(), dim=1)
    merged = torch.stack([run.mean(dim=1) for run in runs], dim=1)
    return merged, durations


def main():
    """Extract supplementary training data with a pretrained SSL disentangler.

    For every clip from the manifests (see ``AudioDataset``), this writes the
    following files to ``dataset.sup_data_dir``:

    * ``mel_spec_<id>.pt``: log-mel spectrogram trimmed to
      ``int(audio_len / hop_length)`` frames;
    * ``<ssl_content_emb_type>_content_embedding_<id>.pt``: content features,
      optionally merged over runs of identical predicted tokens;
    * ``speaker_embedding_<id>.pt``: L2-normalised mean of segment-level
      speaker embeddings;
    * ``duration_embedding_<id>.pt``: per-column durations in units of
      ``ssl_downsampling_factor``.

    With ``--compute_pitch_contours 1`` it also writes pYIN pitch contours and
    ``speaker_pitch_stats.json``.
    """
    parser = argparse.ArgumentParser(description='Evaluate the model')
    parser.add_argument(
        '--ssl_model_ckpt_path', type=str, required=True,
    )
    parser.add_argument('--manifest_paths', type=str, required=True)
    parser.add_argument('--sup_data_dir', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    # choices: an unknown value used to leave final_content_embedding unbound at runtime.
    parser.add_argument(
        '--ssl_content_emb_type',
        type=str,
        default="embedding_and_probs",
        choices=["probs", "embedding", "log_probs", "embedding_and_probs"],
    )
    parser.add_argument('--use_unique_tokens', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pool_workers', type=int, default=30)
    parser.add_argument('--compute_pitch_contours', type=int, default=1)
    parser.add_argument('--num_pitch_per_speaker', type=int, default=None)
    parser.add_argument('--sample_rate', type=int, default=22050)
    parser.add_argument('--pad_multiple', type=int, default=1024)
    parser.add_argument('--ssl_downsampling_factor', type=int, default=4)
    parser.add_argument('--stft_n_fft', type=int, default=1024)
    parser.add_argument('--stft_hop_length', type=int, default=256)
    parser.add_argument('--stft_win_length', type=int, default=1024)
    parser.add_argument('--stft_n_mel', type=int, default=80)
    parser.add_argument('--stft_fmin', type=int, default=0)
    parser.add_argument('--stft_fmax', type=int, default=8000)
    parser.add_argument('--yin_fmax', type=int, default=500)
    parser.add_argument('--segment_length', type=int, default=44100)
    parser.add_argument('--segment_hop_size', type=int, default=22050)
    parser.add_argument('--min_segment_length', type=int, default=22050)
    args = parser.parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    manifest_paths = args.manifest_paths.split(",")
    ssl_model_ckpt_path = args.ssl_model_ckpt_path

    dataset = AudioDataset(
        manifest_paths, pad_multiple=args.pad_multiple, sample_rate=args.sample_rate, sup_data_dir=args.sup_data_dir
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.pad_collate_fn,
        num_workers=args.num_workers,
    )

    ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(ssl_model_ckpt_path, strict=False)
    with open_dict(ssl_model.cfg):
        ssl_model.cfg.preprocessor.exact_pad = True
    # Rebuild the preprocessor so the exact_pad override takes effect.
    ssl_model.preprocessor = hydra.utils.instantiate(ssl_model.cfg.preprocessor)
    ssl_model.preprocessor_disentangler = ssl_model.preprocessor
    ssl_model.eval()
    ssl_model.to(device)

    sample_rate = args.sample_rate
    stft_params = {
        "n_fft": args.stft_n_fft,
        "hop_length": args.stft_hop_length,
        "win_length": args.stft_win_length,
        "n_mel": args.stft_n_mel,
        "sample_rate": sample_rate,
        "pad_multiple": args.pad_multiple,
        "fmin": args.stft_fmin,
        "fmax": args.stft_fmax,
        "yin_fmax": args.yin_fmax,
    }
    fb = (
        torch.tensor(
            librosa.filters.mel(
                sr=sample_rate,
                n_fft=stft_params['n_fft'],
                n_mels=stft_params['n_mel'],
                fmin=stft_params['fmin'],
                fmax=stft_params['fmax'],
            ),
            dtype=torch.float,
        )
        .unsqueeze(0)
        .to(device)
    )

    st = time.time()
    bidx = 0
    wav_and_id_list = []
    for batch in tqdm(dataloader):
        bidx += 1
        with torch.no_grad():
            (
                _,
                _,
                batch_content_embedding,
                batch_content_log_probs,
                batch_encoded_len,
            ) = ssl_model.forward_for_export(
                input_signal=batch['audio'].to(device),
                input_signal_length=batch['audio_len'].to(device),
                normalize_content=True,
            )
            # The last sample is dropped, presumably so a centered STFT of a
            # hop-aligned clip yields exactly audio_len / hop_length frames (see mel_len below).
            batch_mel_specs = get_mel_spectrogram(fb, batch['audio'][:, :-1].to(device), stft_params)

            # Speaker embeddings are computed on fixed-length segments and averaged per
            # utterance using the inclusive row ranges in segment_indices.
            audio_segmented, segment_indices = segment_batch(
                batch, args.segment_length, args.segment_hop_size, args.min_segment_length
            )
            audio_seg_len = torch.tensor([len(segment) for segment in audio_segmented]).to(device).long()
            _, batch_speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
                input_signal=audio_segmented.to(device), input_signal_length=audio_seg_len, normalize_content=True,
            )

            for idx in range(batch['audio'].shape[0]):
                speaker = batch['speaker'][idx].item()
                wav_path = batch['wav_path'][idx]
                wav_text_id = batch['rel_audio_path_as_text_id'][idx]
                wav_and_id_list.append((wav_path, wav_text_id, speaker))

                # Drop padded frames, then lay features out as (dim, frames).
                encoded_len = batch_encoded_len[idx].item()
                content_embedding = batch_content_embedding[idx].detach()[:encoded_len].t()
                # NOTE(review): log-probs are indexed [:, idx, :] (batch on dim 1); confirm against forward_for_export.
                content_log_probs = batch_content_log_probs[:, idx, :].detach()[:encoded_len].t()
                content_probs = torch.exp(content_log_probs)
                duration = torch.ones(content_embedding.shape[1]) * args.ssl_downsampling_factor

                bsi_start, bsi_end = segment_indices[idx]
                speaker_embedding = torch.mean(batch_speaker_embeddings[bsi_start : bsi_end + 1], dim=0)
                speaker_embedding = speaker_embedding / torch.norm(speaker_embedding, p=2)

                if args.ssl_content_emb_type == "probs":
                    final_content_embedding = content_probs
                elif args.ssl_content_emb_type == "embedding":
                    final_content_embedding = content_embedding
                elif args.ssl_content_emb_type == "log_probs":
                    final_content_embedding = content_log_probs
                else:  # "embedding_and_probs"; argparse choices rule out anything else
                    final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)

                if args.use_unique_tokens == 1:
                    token_predictions = torch.argmax(content_probs, dim=0)
                    final_content_embedding, duration = _merge_repeated_tokens(
                        final_content_embedding, token_predictions, args.ssl_downsampling_factor
                    )

                mel_len = int(batch['audio_len'][idx].item() / stft_params['hop_length'])
                item_mel = batch_mel_specs[idx][:, :mel_len]

                content_emb_fp = os.path.join(
                    dataset.sup_data_dir, f"{args.ssl_content_emb_type}_content_embedding_{wav_text_id}.pt"
                )
                speaker_emb_fp = os.path.join(dataset.sup_data_dir, f"speaker_embedding_{wav_text_id}.pt")
                duration_fp = os.path.join(dataset.sup_data_dir, f"duration_embedding_{wav_text_id}.pt")
                mel_spec_fp = os.path.join(dataset.sup_data_dir, f"mel_spec_{wav_text_id}.pt")

                torch.save(item_mel.cpu(), mel_spec_fp)
                torch.save(final_content_embedding.cpu(), content_emb_fp)
                torch.save(speaker_embedding.cpu(), speaker_emb_fp)
                torch.save(duration.cpu(), duration_fp)

        et = time.time()
        # bidx was already incremented at the top of the loop, so it is the 1-based batch number.
        logging.info(
            "Processed Batch {} of {} | Time per batch: {:.4f} s".format(bidx, len(dataloader), (et - st) / bidx)
        )

    if args.compute_pitch_contours == 1:
        speaker_wise_records = {}
        for wav_path, wav_id, speaker in wav_and_id_list:
            speaker_wise_records.setdefault(speaker, []).append(
                {
                    "wav_path": wav_path,
                    "wav_id": wav_id,
                    "sup_data_dir": dataset.sup_data_dir,
                    "stft_params": stft_params,
                    "speaker": speaker,
                }
            )

        filtered_records = []
        for speaker_records in speaker_wise_records.values():
            # Slicing with None keeps every record, so num_pitch_per_speaker=None means "no cap".
            filtered_records += speaker_records[: args.num_pitch_per_speaker]

        with Pool(args.pool_workers) as p:
            p.map(save_pitch_contour, filtered_records)
        compute_pitch_stats(filtered_records)
# Run the extraction only when executed as a script, not on import. This also
# matters for multiprocessing.Pool under the "spawn" start method, which re-imports this module.
if __name__ == '__main__':
    main()