import json
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.utils.data

from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

F0_FRAME_SPACE = 0.005  # F0 analysis frame spacing, in seconds (5 ms)


logger = logging.getLogger(__name__)
class ExpressiveCodeDataConfig(object):
    def __init__(self, json_path):
        with open(json_path, "r") as f:
            self.config = json.load(f)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """f0 quantizer type: "naive" or "precomp" """
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        key = "log" if log else "linear"
        if norm_mean and norm_std:
            key += "_mean_std_norm"
        elif norm_mean:
            key += "_mean_norm"
        else:
            key += "_none_norm"
        return self.config["f0_vq_naive_quantizer"][key]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)
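
# A config sketch for ExpressiveCodeDataConfig (all values illustrative; key names
# follow the accessors above, and the shape of "manifests" is an assumption):
#   {
#     "manifests": {"train": "/data/train", "valid": "/data/valid"},
#     "n_units": 100,
#     "sampling_rate": 16000,
#     "code_hop_size": 320,
#     "f0_vq_type": "naive",
#     "f0_vq_name": "f0_vq32",
#     "f0_vq_n_units": 32,
#     "f0_vq_naive_quantizer": {"log_mean_norm": "/data/f0_quantizers.pt"},
#     "f0_stats": "/data/f0_stats.pt",
#     "multispkr": "parent_name"
#   }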


def get_f0(audio, rate=16000):
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError:
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) and "
            "librosa (`pip install librosa`)."
        )

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,  # seconds -> ms
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0
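
# Usage sketch ("utt.wav" is a hypothetical path): load a mono waveform with
# load_wav below, then extract one F0 value per F0_FRAME_SPACE (5 ms) frame:
#   wav, sr = load_wav("utt.wav")
#   f0 = get_f0(wav, rate=sr)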


def interpolate_f0(f0):
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise ImportError("Please install scipy (`pip install scipy`)")

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp
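
# Example: unvoiced (zero) frames between voiced anchors are filled linearly,
# while frames outside the voiced range stay 0 (fill_value=0):
#   interpolate_f0(torch.tensor([0., 2., 0., 4.])) -> tensor([0., 2., 3., 4.])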


def naive_quantize(x, edges):
    # bin index = number of edges each value exceeds
    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
    return bin_idx
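
# Binning semantics, with hypothetical edges: for edges = torch.tensor([1., 2., 3.]),
# naive_quantize(torch.tensor([0.5, 1.5, 9.0]), edges) -> tensor([0, 1, 3]); values
# below edges[0] map to bin 0, and values above edges[-1] map to bin len(edges).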


def load_wav(full_path):
    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile (`pip install SoundFile`)")

    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate


def parse_code(code_str, dictionary, append_eos):
    # Run-length encode: collapse consecutive repeats into (code, duration) pairs.
    code, duration = torch.unique_consecutive(
        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
    )
    code = " ".join(map(str, code.tolist()))
    code = dictionary.encode_line(code, append_eos=append_eos).short()

    if append_eos:
        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)
    duration = duration.short()
    return code, duration
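
# Example: parse_code("5 5 5 9 9 2", dictionary, append_eos=True) yields the
# dictionary indices of units [5, 9, 2] plus EOS, with durations [3, 2, 1, 0]
# (a zero-length duration is appended for the EOS token).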


def parse_manifest(manifest, dictionary):
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info:
            # Each manifest line is a Python dict literal with keys "audio",
            # optionally "speaker", and one of the unit keys checked below.
            sample = eval(line.strip())
            if "cpc_km100" in sample:
                k = "cpc_km100"
            elif "hubert_km100" in sample:
                k = "hubert_km100"
            elif "phone" in sample:
                k = "phone"
            else:
                assert False, "unknown format"
            code = sample[k]
            code, duration = parse_code(code, dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers


def parse_speaker(path, method):
    if isinstance(path, str):
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()
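
# Example: parse_speaker("/data/spk1/utt0.wav", "parent_name") -> "spk1";
# parse_speaker("spk1_utt0.wav", "_") -> "spk1".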


def get_f0_by_filename(filename, tgt_sampling_rate):
    audio, sampling_rate = load_wav(filename)
    if sampling_rate != tgt_sampling_rate:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    f0 = get_f0(audio, rate=tgt_sampling_rate)
    f0 = torch.from_numpy(f0.astype(np.float32))
    return f0


def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    """Trim/pad frame-level F0 to the expected code length, then average it over
    each code segment so there is exactly one F0 value per code token."""
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        f0 = f0[:targ_len]
    elif diff < 0:
        f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)

    f0_offset = 0.0
    seg_f0s = []
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        seg_f0 = seg_f0[seg_f0 != 0]
        if len(seg_f0) == 0:
            # fully unvoiced segment: represent it with 0
            seg_f0 = torch.tensor(0).type(seg_f0.type())
        else:
            seg_f0 = seg_f0.mean()
        seg_f0s.append(seg_f0)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)
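
# Example ratio (assuming a hypothetical code_hop_size of 320 samples at 16 kHz):
# each code token spans 20 ms, so
# f0_code_ratio = 320 / (16000 * F0_FRAME_SPACE) = 4 F0 frames per code token.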


class Paddings(object):
    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code = code_val
        self.dur = dur_val
        self.f0 = f0_val


class Shifts(object):
    def __init__(self, shifts_str, pads):
        self._shifts = list(map(int, shifts_str.split(",")))
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(s for s in self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        assert seq.ndim == 1
        bos = seq.new_full((left_pad_num,), pad)
        eos = seq.new_full((right_pad_num,), pad)
        seq = torch.cat([bos, seq, eos])
        # mask is True on the padded (shifted-in) positions, False on real data
        mask = torch.ones_like(seq).bool()
        mask[left_pad_num : len(seq) - right_pad_num] = 0
        return seq, mask

    def __call__(self, code, dur, f0):
        if self.extra_length == 0:
            code_mask = torch.zeros_like(code).bool()
            dur_mask = torch.zeros_like(dur).bool()
            f0_mask = torch.zeros_like(f0).bool()
            return code, code_mask, dur, dur_mask, f0, f0_mask

        code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, self.extra_length - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(
            f0, self.f0, self.extra_length - self.f0, self.pads.f0
        )
        return code, code_mask, dur, dur_mask, f0, f0_mask
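
# Example: Shifts("1,2", pads) delays the duration stream by 1 step and the F0
# stream by 2 steps relative to the code stream; all three are padded to a common
# length of len(seq) + 2, and each returned mask flags the padded positions.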


class CodeDataset(FairseqDataset):
    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        self.discrete_dur = discrete_dur
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # data handlers are loaded lazily in get_data_handlers()
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None
        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # dur_val
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0
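
    # Construction sketch (hypothetical paths; the three dictionaries come from the
    # fairseq task setup, and "config.json" follows ExpressiveCodeDataConfig above):
    #   config = ExpressiveCodeDataConfig("config.json")
    #   dataset = CodeDataset(
    #       manifest="/data/train", dictionary=code_dict, dur_dictionary=dur_dict,
    #       f0_dictionary=f0_dict, config=config, discrete_dur=False,
    #       discrete_f0=False, log_f0=True, normalize_f0_mean=True,
    #       normalize_f0_std=False, interpolate_f0=True,
    #   )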

    def get_data_handlers(self):
        logger.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")

    def preprocess_f0(self, f0, stats):
        """
        1. interpolate unvoiced frames (optional)
        2. log transform (keep unvoiced frames at 0)
        3. normalize by mean and/or std over voiced frames (optional)
        """
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        mask = f0 != 0  # only transform voiced frames
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0

    def _get_raw_item(self, index):
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0

    def __getitem__(self, index):
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # prepend duration 0 for the BOS token (EOS already got 0 in parse_code)
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
                if self.return_continuous_f0:
                    raw_f0 = f0
                    raw_f0 = torch.cat(
                        [raw_f0.new([self.f0_dictionary.bos()]), raw_f0]
                    )
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            if self.multispkr:
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
            else:
                f0 = self.preprocess_f0(f0, self.f0_stats)
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None

        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # shift each stream by one step for next-token prediction
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats
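
    # For a padded stream s of length T, "source" = s[:-1] and "target" = s[1:],
    # so position t predicts token t+1; ORing adjacent mask positions drops any
    # prediction whose input or target is shifted-in padding.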

    def __len__(self):
        return len(self.starts)

    def size(self, index):
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        return self.size(index)

    def collater(self, samples):
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result
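
    # Collation note: every field is padded to the longest sample in the batch;
    # the masks are padded with 1, so downstream code can ignore batch padding
    # the same way it ignores shift padding.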