|
|
|
|
|
"""TTS model AR prior training.""" |
|
|
|
import argparse |
|
import logging |
|
from pathlib import Path |
|
import sys |
|
import time |
|
from typing import Optional |
|
from typing import Sequence |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import numpy as np |
|
import torch |
|
from typeguard import check_argument_types |
|
|
|
from espnet.utils.cli_utils import get_commandline_args |
|
from espnet2.tasks.tts import TTSTask |
|
from espnet2.torch_utils.device_funcs import to_device |
|
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed |
|
from espnet2.tts.duration_calculator import DurationCalculator |
|
from espnet2.tts.fastspeech import FastSpeech |
|
from espnet2.tts.fastspeech2 import FastSpeech2 |
|
from espnet2.tts.fastespeech import FastESpeech |
|
from espnet2.tts.tacotron2 import Tacotron2 |
|
from espnet2.tts.transformer import Transformer |
|
from espnet2.utils import config_argparse |
|
from espnet2.utils.get_default_kwargs import get_default_kwargs |
|
from espnet2.utils.griffin_lim import Spectrogram2Waveform |
|
from espnet2.utils.nested_dict_action import NestedDictAction |
|
from espnet2.utils.types import str2bool |
|
from espnet2.utils.types import str2triple_str |
|
from espnet2.utils.types import str_or_none |
|
|
|
from espnet2.tts.prosody_encoder import ARPrior |
|
|
|
import torch.optim as optim |
|
|
|
|
|
class Text2Speech:
    """Text2Speech class.

    Wraps a trained ESPnet TTS model for AR prior training: the model is built
    from ``train_config`` / ``model_file``, architecture-specific decoding
    options are collected once, and each call runs ``model.inference`` with
    ``train_ar_prior=True`` and returns the resulting AR prior loss.

    """

    def __init__(
        self,
        train_config: Optional[Union[Path, str]],
        model_file: Optional[Union[Path, str]] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        vocoder_conf: Optional[dict] = None,
        dtype: str = "float32",
        device: str = "cpu",
    ):
        """Load the model and prepare decoding and vocoder settings.

        Args:
            train_config: Training configuration used to build the model.
            model_file: Trained model parameter file.
            threshold: Passed to decoding for Tacotron2 / Transformer only.
            minlenratio: Passed to decoding for Tacotron2 / Transformer only.
            maxlenratio: Passed to decoding for Tacotron2 / Transformer only.
            use_teacher_forcing: Passed to decoding for every model type; also
                makes ``speech`` mandatory in ``__call__``.
            use_att_constraint: Passed to decoding for Tacotron2 only.
            backward_window: Passed to decoding for Tacotron2 only.
            forward_window: Passed to decoding for Tacotron2 only.
            speed_control_alpha: Passed as ``alpha`` for the FastSpeech family.
            vocoder_conf: Griffin-Lim settings. The caller's dict is copied,
                never modified.
            dtype: Name of the ``torch`` dtype the model is cast to.
            device: Device on which the model is built.

        """
        assert check_argument_types()

        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device
        )
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing

        logging.info(f"Normalization:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")

        # Only forward the options each architecture's inference understands.
        decode_config = {}
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_config.update(
                {
                    "threshold": threshold,
                    "maxlenratio": maxlenratio,
                    "minlenratio": minlenratio,
                }
            )
        if isinstance(self.tts, Tacotron2):
            decode_config.update(
                {
                    "use_att_constraint": use_att_constraint,
                    "forward_window": forward_window,
                    "backward_window": backward_window,
                }
            )
        if isinstance(self.tts, (FastSpeech, FastSpeech2, FastESpeech)):
            decode_config.update({"alpha": speed_control_alpha})
        decode_config.update({"use_teacher_forcing": use_teacher_forcing})
        self.decode_config = decode_config

        # Work on a copy: merging in the feature extractor's parameters must
        # not mutate the dict owned by the caller (e.g. parsed CLI values).
        vocoder_conf = {} if vocoder_conf is None else dict(vocoder_conf)
        if self.feats_extract is not None:
            vocoder_conf.update(self.feats_extract.get_parameters())
        if all(key in vocoder_conf for key in ("n_fft", "n_shift", "fs")):
            self.spc2wav = Spectrogram2Waveform(**vocoder_conf)
            logging.info(f"Vocoder: {self.spc2wav}")
        else:
            self.spc2wav = None
            logging.info("Vocoder is not used because vocoder_conf is not sufficient")

    def __call__(
        self,
        text: Union[str, torch.Tensor, np.ndarray],
        speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
        durations: Optional[Union[torch.Tensor, np.ndarray]] = None,
        ref_embs: Optional[torch.Tensor] = None,
    ):
        """Run the model on one utterance and return its AR prior loss.

        Args:
            text: Raw text (converted with the training preprocessor) or
                already-tokenized input.
            speech: Reference speech; required when ``use_speech`` is True.
            durations: Optional durations, forwarded only when given.
            ref_embs: Optional reference embeddings, always forwarded.

        Returns:
            The last element of the 6-tuple returned by ``model.inference``
            when called with ``train_ar_prior=True`` (the AR prior loss).

        Raises:
            RuntimeError: If ``use_speech`` is True and ``speech`` is None.

        """
        assert check_argument_types()

        if self.use_speech and speech is None:
            raise RuntimeError("missing required argument: 'speech'")

        if isinstance(text, str):
            # The preprocessor expects (key, data); the key is a placeholder.
            text = self.preprocess_fn("<dummy>", {"text": text})["text"]
        batch = {"text": text, "ref_embs": ref_embs}
        if speech is not None:
            batch["speech"] = speech
        if durations is not None:
            batch["durations"] = durations

        batch = to_device(batch, self.device)
        _outs, _outs_denorm, _probs, _att_ws, _ref_embs, ar_prior_loss = (
            self.model.inference(**batch, **self.decode_config, train_ar_prior=True)
        )
        return ar_prior_loss

    @property
    def fs(self) -> Optional[int]:
        """Sampling rate of the Griffin-Lim vocoder, or None if there is none."""
        if self.spc2wav is not None:
            return self.spc2wav.fs
        return None

    @property
    def use_speech(self) -> bool:
        """Check whether to require speech in inference.

        Returns:
            bool: True if speech is required else False.

        """
        # NOTE(review): a TTS without a ``use_gst`` attribute is treated as
        # requiring speech (default True), presumably because the prosody
        # encoder needs reference speech for AR prior training.
        return self.use_teacher_forcing or getattr(self.tts, "use_gst", True)
|
|
|
|
|
def _build_prior_loader(
    text2speech,
    data_path_and_name_and_type,
    dtype: str,
    batch_size: int,
    key_file: Optional[str],
    num_workers: int,
    allow_variable_data_keys: bool,
):
    """Build the streaming data iterator, dropping speech if it is unused."""
    if not text2speech.use_speech:
        data_path_and_name_and_type = [
            entry for entry in data_path_and_name_and_type if entry[1] != "speech"
        ]
    return TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False),
        collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )


def _attach_ar_prior(text2speech, device: str, dtype: str) -> torch.nn.Module:
    """Freeze the loaded model and attach a fresh ARPrior; return the prior."""
    for param in text2speech.model.parameters():
        param.requires_grad = False

    ar_prior = ARPrior(
        num_embeddings=32,
        embedding_dim=384,
        lstm_num_layers=1,
        lstm_bidirectional=False,
    )
    # The rest of the model was already built on ``device`` and cast to
    # ``dtype``; put the new prior in the same place and precision.
    ar_prior.to(device=device, dtype=getattr(torch, dtype))
    text2speech.model.tts.prosody_encoder.ar_prior = ar_prior
    return ar_prior


def _run_prior_epoch(text2speech, ar_prior, loader, optimizer) -> Tuple[float, int]:
    """Train the prior for one pass over ``loader``; return (loss sum, steps)."""
    # Keep the frozen model in eval mode so mode-dependent layers (dropout,
    # batch norm running statistics) are not perturbed; only the prior trains.
    text2speech.model.eval()
    ar_prior.train()

    total_loss = 0.0
    num_steps = 0
    for keys, batch in loader:
        assert isinstance(batch, dict), type(batch)
        assert all(isinstance(s, str) for s in keys), keys
        _bs = len(next(iter(batch.values())))
        assert _bs == 1, _bs

        # Text2Speech takes one unbatched utterance: strip the batch axis and
        # the "*_lengths" entries added by the collate function.
        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

        optimizer.zero_grad()
        with torch.enable_grad():
            loss = text2speech(**batch)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        print("Loss: {:.4f}".format(step_loss))
        total_loss += step_loss
        num_steps += 1
    return total_loss, num_steps


def train_prior(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
    num_epochs: int = 500,
    learning_rate: float = 0.001,
    momentum: float = 0.9,
    save_interval: int = 10,
):
    """Perform AR prior training.

    All parameters of the pretrained model are frozen, a new ``ARPrior`` is
    attached as ``model.tts.prosody_encoder.ar_prior`` and only its
    parameters are optimized with SGD on the loss returned by ``Text2Speech``.
    The full model ``state_dict`` is written to ``output_dir`` as
    ``with_prior_<epoch>.pth`` every ``save_interval`` epochs (starting with
    epoch 0) and as ``with_prior.pth`` after the last epoch.

    Args:
        output_dir: Directory receiving the checkpoints (created if missing).
        batch_size: Must be 1; batched training is not implemented.
        dtype: Name of the ``torch`` dtype for the model and the data.
        ngpu: 0 for CPU, 1 for a single GPU.
        seed: Random seed.
        num_workers: Number of data loader workers.
        log_level: Logging level.
        data_path_and_name_and_type: ``(path, name, type)`` data triples.
        key_file: Optional key file restricting the utterances used.
        train_config: Training configuration of the pretrained model.
        model_file: Parameters of the pretrained model.
        threshold, minlenratio, maxlenratio, use_teacher_forcing,
        use_att_constraint, backward_window, forward_window,
        speed_control_alpha, vocoder_conf: Forwarded to ``Text2Speech``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        num_epochs: Number of passes over the data.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        save_interval: Epoch interval for intermediate checkpoints; a value
            <= 0 disables them.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.

    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch AR prior training is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU AR prior training is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"
    set_all_random_seed(seed)

    text2speech = Text2Speech(
        train_config=train_config,
        model_file=model_file,
        threshold=threshold,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        use_teacher_forcing=use_teacher_forcing,
        use_att_constraint=use_att_constraint,
        backward_window=backward_window,
        forward_window=forward_window,
        speed_control_alpha=speed_control_alpha,
        vocoder_conf=vocoder_conf,
        dtype=dtype,
        device=device,
    )

    loader = _build_prior_loader(
        text2speech,
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        allow_variable_data_keys=allow_variable_data_keys,
    )

    ar_prior = _attach_ar_prior(text2speech, device, dtype)
    # Only the prior has trainable parameters; everything else is frozen.
    optimizer = optim.SGD(ar_prior.parameters(), lr=learning_rate, momentum=momentum)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    since = time.time()
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        total_loss, num_steps = _run_prior_epoch(
            text2speech, ar_prior, loader, optimizer
        )
        if num_steps > 0:
            print("Epoch mean loss: {:.4f}".format(total_loss / num_steps))
        else:
            logging.warning("No data was loaded during epoch %d", epoch)

        if save_interval > 0 and epoch % save_interval == 0:
            torch.save(
                text2speech.model.state_dict(),
                out_dir / "with_prior_{}.pth".format(epoch),
            )

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )

    torch.save(text2speech.model.state_dict(), out_dir / "with_prior.pth")
|
|
|
|
|
def get_parser():
    """Get argument parser.

    Returns:
        A ``config_argparse.ArgumentParser`` whose parsed arguments (minus
        ``config``) match the parameters of ``train_prior``.

    """
    parser = config_argparse.ArgumentParser(
        description="TTS AR prior training",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The path of output directory",
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for AR prior training (only 1 is supported)",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument(
        "--key_file",
        type=str_or_none,
    )
    group.add_argument(
        "--allow_variable_data_keys",
        type=str2bool,
        default=False,
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--train_config",
        type=str,
        help="Training configuration file.",
    )
    group.add_argument(
        "--model_file",
        type=str,
        help="Model parameter file.",
    )

    group = parser.add_argument_group("Decoding related")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=10.0,
        help="Maximum length ratio in decoding",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Minimum length ratio in decoding",
    )
    group.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold value in decoding",
    )
    group.add_argument(
        "--use_att_constraint",
        type=str2bool,
        default=False,
        help="Whether to use attention constraint",
    )
    group.add_argument(
        "--backward_window",
        type=int,
        default=1,
        help="Backward window value in attention constraint",
    )
    group.add_argument(
        "--forward_window",
        type=int,
        default=3,
        help="Forward window value in attention constraint",
    )
    group.add_argument(
        "--use_teacher_forcing",
        type=str2bool,
        default=False,
        help="Whether to use teacher forcing",
    )
    # Decoding option as well: keep it in the decoding group of --help.
    group.add_argument(
        "--speed_control_alpha",
        type=float,
        default=1.0,
        help="Alpha in FastSpeech to change the speed of generated speech",
    )

    group = parser.add_argument_group("Griffin-Lim related")
    group.add_argument(
        "--vocoder_conf",
        action=NestedDictAction,
        default=get_default_kwargs(Spectrogram2Waveform),
        help="The configuration for Griffin-Lim",
    )
    return parser
|
|
|
|
|
def main(cmd: Optional[Sequence[str]] = None):
    """Parse command-line arguments and run AR prior training.

    Args:
        cmd: Arguments to parse; when None, ``argparse`` reads ``sys.argv``.

    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    # ``config`` is not a parameter of train_prior (presumably the
    # config-file option added by config_argparse), so drop it.
    kwargs.pop("config", None)
    train_prior(**kwargs)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse sys.argv and start AR prior training.
    main()
|
|