import os

import torch
from torch.utils.data import DataLoader
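
# Project-local modules: data pipeline, model definitions, and training helpers.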
from data_utils import TextMelLoader, TextMelCollate
import models
import commons
import utils


class FlowGenerator_DDI(models.FlowGenerator):
    """A helper for data-dependent initialization (DDI): switches every
    decoder flow that supports DDI into that mode, so its parameters
    (e.g. activation-normalization statistics) are initialized from the
    first batch of real data instead of fixed defaults."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for f in self.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(True)


def main():
    # Load hyperparameters, set up logging, and fix the RNG seed so the
    # initialization pass is reproducible.
    hps = utils.get_hparams()
    logger = utils.get_logger(hps.log_dir)
    logger.info(hps)
    utils.check_git_hash(hps.log_dir)
    torch.manual_seed(hps.train.seed)
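
    # Build the text-to-mel dataset and loader; TextMelCollate(1) zero-pads
    # each batch (the argument is the number of frames per step, kept at 1).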
    train_dataset = TextMelLoader(hps.data.training_files, hps.data)
    collate_fn = TextMelCollate(1)
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=True,
        batch_size=hps.train.batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
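
    # The input vocabulary is punctuation plus characters; when add_blank is
    # set, one extra symbol id is reserved (True adds 1 to the count).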
    symbols = hps.data.punc + hps.data.chars
    generator = FlowGenerator_DDI(
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).cuda()
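
    # commons.Adam wraps the optimizer together with a warmup learning-rate
    # schedule (Noam-style in the Glow-TTS codebase), configured from hps.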
    optimizer_g = commons.Adam(
        generator.parameters(),
        scheduler=hps.train.scheduler,
        dim_model=hps.model.hidden_channels,
        warmup_steps=hps.train.warmup_steps,
        lr=hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
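
    # A single forward pass in training mode (gen=False) is enough for the
    # DDI-enabled flows to initialize from real data, so stop after one batch.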
    generator.train()
    for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader):
        x, x_lengths = x.cuda(), x_lengths.cuda()
        y, y_lengths = y.cuda(), y_lengths.cuda()
        _ = generator(x, x_lengths, y, y_lengths, gen=False)
        break
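
    # Save the data-initialized weights; the main training script can resume
    # from this checkpoint at iteration 0.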
    utils.save_checkpoint(
        generator,
        optimizer_g,
        hps.train.learning_rate,
        0,
        os.path.join(hps.model_dir, "ddi_G.pth"),
    )


if __name__ == "__main__":
    main()