File size: 2,077 Bytes
3b92d66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import json
import argparse
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from data_utils import TextMelLoader, TextMelCollate
import models
import commons
import utils


class FlowGenerator_DDI(models.FlowGenerator):
    """FlowGenerator that puts its decoder flows into DDI mode on construction.

    Construction is delegated unchanged to ``models.FlowGenerator``. Afterwards,
    every flow in ``self.decoder.flows`` whose ``set_ddi`` attribute exists and
    is truthy gets ``set_ddi(True)`` called on it. Flows without such an
    attribute are left as they are. Used as a helper for data-dependent
    initialization (DDI).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily pick out the flows that support toggling DDI, in decoder order.
        ddi_capable = (
            flow for flow in self.decoder.flows if getattr(flow, "set_ddi", False)
        )
        for flow in ddi_capable:
            flow.set_ddi(True)


def main():
    """Run one data-dependent-initialization pass and save it as a checkpoint.

    Loads hyper-parameters via ``utils.get_hparams()``, sets up logging, and
    seeds torch with ``hps.train.seed``. It then builds the training
    ``DataLoader`` and a ``FlowGenerator_DDI`` (moved to CUDA) plus its
    optimizer. One forward pass (``gen=False``) is run on the first training
    batch. Finally the generator and optimizer are written with
    ``utils.save_checkpoint`` to ``<hps.model_dir>/ddi_G.pth`` with
    iteration 0.

    Raises:
        RuntimeError: if the training loader yields no batches (for example
            when the dataset has fewer items than ``hps.train.batch_size``,
            because ``drop_last=True`` discards the incomplete batch). Without
            a batch no initialization can happen, so writing a checkpoint
            would silently produce an un-initialized "DDI" model.
    """
    hps = utils.get_hparams()
    logger = utils.get_logger(hps.log_dir)
    logger.info(hps)
    utils.check_git_hash(hps.log_dir)

    torch.manual_seed(hps.train.seed)

    train_dataset = TextMelLoader(hps.data.training_files, hps.data)
    collate_fn = TextMelCollate(1)
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=True,
        batch_size=hps.train.batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    symbols = hps.data.punc + hps.data.chars
    # An enabled ``add_blank`` option contributes one extra vocabulary entry.
    n_vocab = len(symbols) + int(getattr(hps.data, "add_blank", False))
    generator = FlowGenerator_DDI(
        n_vocab,
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).cuda()
    optimizer_g = commons.Adam(
        generator.parameters(),
        scheduler=hps.train.scheduler,
        dim_model=hps.model.hidden_channels,
        warmup_steps=hps.train.warmup_steps,
        lr=hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    generator.train()
    # Only the first batch is needed for the initialization pass.
    batch = next(iter(train_loader), None)
    if batch is None:
        raise RuntimeError(
            "Training loader yielded no batches; cannot run data-dependent "
            "initialization. Check hps.data.training_files and "
            "hps.train.batch_size (incomplete batches are dropped)."
        )
    x, x_lengths, y, y_lengths = batch
    x, x_lengths = x.cuda(), x_lengths.cuda()
    y, y_lengths = y.cuda(), y_lengths.cuda()

    # The output is discarded and no backward pass follows, so skip building
    # the autograd graph to save GPU memory.
    with torch.no_grad():
        _ = generator(x, x_lengths, y, y_lengths, gen=False)

    utils.save_checkpoint(
        generator,
        optimizer_g,
        hps.train.learning_rate,
        0,
        os.path.join(hps.model_dir, "ddi_G.pth"),
    )


if __name__ == "__main__":
    # Run the initialization only when executed as a script, not on import.
    main()