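"""Training entry point for the img2tex image-to-LaTeX Transformer.

On first run, check_setup() builds and caches a LatexImageDataModule and a TeX
tokenizer; the script then trains and tests the model with PyTorch Lightning.
"""
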
from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
from data_preprocessing import LatexImageDataModule
from model import Transformer
from utils import LogImageTexCallback

import argparse
import os
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch


# TODO: update Python, make TeX tokens always decodable, ensemble the last checkpoints,
#  clear checkpoint data, build the full dataset, train, export the model to TorchScript,
#  write a Spaces interface

def check_setup():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
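    # First run only: build and cache the datamodule and tokenizer artifacts
    # that the rest of the script loads from disk.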
    if not os.path.isfile(DATAMODULE_PATH):
        datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
        torch.save(datamodule, DATAMODULE_PATH)
    if not os.path.isfile(TOKENIZER_PATH):
        datamodule = torch.load(DATAMODULE_PATH)
        datamodule.train_tokenizer()


def parse_args():
    parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("gpus", type=int, default=None,
                        help=f"IDs of the GPUs in range 0..{torch.cuda.device_count()} to train on; "
                             "if not provided, trains on the CPU", nargs="*")
    parser.add_argument("-l", "-log", help="Save logs of the run to the W&B logger (default: False)", default=False,
                        action="store_true", dest="log")
    parser.add_argument("-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")

    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]

    datamodule = torch.load(DATAMODULE_PATH)
    parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
                        help="Create a new datamodule and exit; current parameters:\n" +
                             "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))

    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                        ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
                        help="Transformer init args, reference values:\n" +
                             "\n".join(f"{k}\t{v}" for k, v in transformer_args))

    args = parser.parse_args()

    if args.datamodule_args:
        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))

    if args.transformer_args:
        # -t values arrive as strings, so cast each one to the type of its
        # reference value (int for the sizes, float for dropout).
        args.transformer_args = {name: type(default)(value)
                                 for (name, default), value in zip(transformer_args, args.transformer_args)}
    else:
        args.transformer_args = dict(transformer_args)

    return args
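
# Illustrative invocations (assuming the script is saved as train.py; the file
# name is not fixed anywhere in this snippet):
#   python train.py 0 1 -l -m 100           # train on GPUs 0 and 1, log to W&B, at most 100 epochs
#   python train.py -d 1024 128 16 5        # rebuild the datamodule with new parameters and exit
#   python train.py -t 6 6 512 8 2048 0.1   # override the Transformer init args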


def main():
    check_setup()
    args = parse_args()
    if args.datamodule_args:
        datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
                                          image_height=args.datamodule_args["image_height"],
                                          batch_size=args.datamodule_args["batch_size"],
                                          random_magnitude=args.datamodule_args["random_magnitude"])
        datamodule.train_tokenizer()
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
        torch.save(datamodule, DATAMODULE_PATH)
        return

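    # Normal training path: load the cached artifacts created by check_setup().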
    datamodule = torch.load(DATAMODULE_PATH)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    logger = None
    callbacks = []
    if args.log:
        logger = WandbLogger("img2tex", log_model=True)
        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
                     LearningRateMonitor(logging_interval="step"),
                     ModelCheckpoint(save_top_k=10,
                                     monitor="val_loss",
                                     mode="min",
                                     filename="img2tex-{epoch:02d}-{val_loss:.2f}")]

    trainer = Trainer(max_epochs=args.max_epochs,
                      accelerator="cpu" if args.gpus is None else "gpu",
                      gpus=args.gpus,
                      logger=logger,
                      # DDP only applies when training on GPUs
                      strategy="ddp_find_unused_parameters_false" if args.gpus else None,
                      enable_progress_bar=True,
                      default_root_dir=TRAINER_DIR,
                      callbacks=callbacks,
                      check_val_every_n_epoch=5)

    transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
                              num_decoder_layers=args.transformer_args["num_decoder_layers"],
                              d_model=args.transformer_args["d_model"],
                              nhead=args.transformer_args["nhead"],
                              dim_feedforward=args.transformer_args["dim_feedforward"],
                              dropout=args.transformer_args["dropout"],
                              image_width=datamodule.hparams["image_width"],
                              image_height=datamodule.hparams["image_height"],
                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))

    trainer.fit(transformer, datamodule=datamodule)
    trainer.test(transformer, datamodule=datamodule)


if __name__ == "__main__":
    main()