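# NOTE(review): the following six lines look like file-viewer page residue
# (author, commit message, commit hash, toolbar text, scan status, file size)
# captured along with the source; they are kept as comments, not code.
# harveen
# Add Tamil
# 3b92d66
# raw history blame
# No virus
# 2.08 kB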
import os
import json
import argparse
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from data_utils import TextMelLoader, TextMelCollate
import models
import commons
import utils
class FlowGenerator_DDI(models.FlowGenerator):
    """FlowGenerator variant that switches on data-dependent initialization.

    Construction is delegated unchanged to ``models.FlowGenerator``; afterwards
    every entry of ``self.decoder.flows`` that exposes a truthy ``set_ddi``
    attribute has ``set_ddi(True)`` called on it. Flows without that attribute
    are left alone.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for flow in self.decoder.flows:
            enable_ddi = getattr(flow, "set_ddi", None)
            if not enable_ddi:
                # This flow has no DDI switch; nothing to do.
                continue
            enable_ddi(True)
def main():
    """Run a single data-dependent-initialization pass and save the model.

    Steps, in order:
      1. Load hyperparameters with ``utils.get_hparams()``, log them, and call
         ``utils.check_git_hash`` on the log directory.
      2. Seed torch with ``hps.train.seed`` and build a shuffled training
         ``DataLoader``.
      3. Build a ``FlowGenerator_DDI`` on the GPU and its ``commons.Adam``
         optimizer.
      4. Push exactly one training batch through the generator in train mode
         (the output is discarded; the pass is run for its effect on the
         model).
      5. Save generator and optimizer to ``ddi_G.pth`` inside
         ``hps.model_dir``.

    Raises:
        RuntimeError: if the training loader yields no batch. With
            ``drop_last=True`` this happens whenever the dataset holds fewer
            than ``hps.train.batch_size`` samples; saving a checkpoint in that
            case would produce a model that never saw the initialization pass.
    """
    hps = utils.get_hparams()
    logger = utils.get_logger(hps.log_dir)
    logger.info(hps)
    utils.check_git_hash(hps.log_dir)

    torch.manual_seed(hps.train.seed)

    train_dataset = TextMelLoader(hps.data.training_files, hps.data)
    # NOTE(review): the meaning of the literal 1 is defined by TextMelCollate
    # in data_utils (presumably frames per step) -- confirm there.
    collate_fn = TextMelCollate(1)
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=True,
        batch_size=hps.train.batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    symbols = hps.data.punc + hps.data.chars
    generator = FlowGenerator_DDI(
        # One extra symbol slot when the optional ``add_blank`` flag is truthy.
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).cuda()
    optimizer_g = commons.Adam(
        generator.parameters(),
        scheduler=hps.train.scheduler,
        dim_model=hps.model.hidden_channels,
        warmup_steps=hps.train.warmup_steps,
        lr=hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    generator.train()

    # Only the first batch is needed. Fetch it explicitly so that an empty
    # loader is reported instead of silently skipping the initialization pass.
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise RuntimeError(
            "Training loader produced no batches; cannot run data-dependent "
            "initialization (is the dataset smaller than the batch size?)"
        )

    x, x_lengths, y, y_lengths = first_batch
    x, x_lengths = x.cuda(), x_lengths.cuda()
    y, y_lengths = y.cuda(), y_lengths.cuda()
    _ = generator(x, x_lengths, y, y_lengths, gen=False)

    utils.save_checkpoint(
        generator,
        optimizer_g,
        hps.train.learning_rate,
        0,  # NOTE(review): presumably the iteration/epoch counter -- confirm in utils
        os.path.join(hps.model_dir, "ddi_G.pth"),
    )
# Run the initialization pass only when executed as a script, not on import.
if __name__ == "__main__":
    main()