ML2TransformerApp / train.py
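"""Training entry point for the img2tex Transformer.

Example invocation (flag values are illustrative and assume two visible GPUs):
    python train.py -g 0 1 -m 50 -b 16 -l
"""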
import argparse
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from constants import TRAINER_DIR
from data_preprocessing import LatexImageDataModule
from model import Transformer
from utils import LogImageTexCallback

# TODO: update python, maybe the model doesn't train because of the ignored special index in CrossEntropyLoss?
#  crop image, adjust brightness, make tex tokens always decodable,
#  save only datamodule state?, ensemble last checkpoints, early stopping


def parse_args():
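    """Parse command-line training arguments.

    Values passed via -t are positional overrides for the transformer defaults,
    applied in the order listed in transformer_args below.
    """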
parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs")
parser.add_argument("-g", "-gpus", metavar="GPUS", type=int, choices=list(range(torch.cuda.device_count())),
help="ids of gpus to train on, if not provided, then trains on cpu", nargs="+", dest="gpus")
parser.add_argument("-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
action="store_true", dest="log")
parser.add_argument("-width", help="width of images, default 1024", default=1024, type=int)
parser.add_argument("-height", help="height of images, default 128", default=128, type=int)
parser.add_argument("-r", "-randomize", default=5, type=int, dest="random_magnitude", choices=range(10),
help="add random augments to images of provided magnitude in range 0..9, default 5")
parser.add_argument("-b", "-batch-size", help="batch size, default 16", default=16,
type=int, dest="batch_size")
transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
("dim_feedforward", 2048), ("dropout", 0.1)]
parser.add_argument("-t", "-transformer-args", dest="transformer_args", nargs='+', default=[],
help="transformer init args:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
args = parser.parse_args()
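    # Apply positional -t overrides onto the defaults, casting each value to the
    # type of its default (int for layer counts and sizes, float for dropout).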
    for i, parameter in enumerate(args.transformer_args):
        name, default = transformer_args[i]
        transformer_args[i] = (name, type(default)(parameter))
    args.transformer_args = dict(transformer_args)
return args


def main():
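    """Train the img2tex Transformer and save the final checkpoint."""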
args = parse_args()
datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
batch_size=args.batch_size, random_magnitude=args.random_magnitude)
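    # Prepare the data; the datamodule's tex tokenizer supplies the target vocab size and [PAD] id used below.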
datamodule.prepare_data()
if args.log:
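        # Log to Weights & Biases and keep the ten best checkpoints by validation loss.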
        logger = WandbLogger("img2tex", log_model=True)
callbacks = [LogImageTexCallback(logger),
LearningRateMonitor(logging_interval='step'),
ModelCheckpoint(save_top_k=10,
monitor="val_loss",
mode="min",
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
else:
logger = None
callbacks = []
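    # DDP across the requested GPUs (CPU if none were given); run validation every 5 epochs.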
trainer = Trainer(max_epochs=args.max_epochs,
accelerator="cpu" if args.gpus is None else "gpu",
gpus=args.gpus,
logger=logger,
strategy="ddp",
enable_progress_bar=True,
default_root_dir=TRAINER_DIR,
callbacks=callbacks,
check_val_every_n_epoch=5)
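    # Build the model from the parsed hyperparameters plus the datamodule-derived
    # image size, target vocabulary size and padding index.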
transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
num_decoder_layers=args.transformer_args['num_decoder_layers'],
d_model=args.transformer_args['d_model'],
nhead=args.transformer_args['nhead'],
dim_feedforward=args.transformer_args['dim_feedforward'],
dropout=args.transformer_args['dropout'],
image_width=datamodule.hparams['image_width'],
image_height=datamodule.hparams['image_height'],
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
trainer.fit(transformer, datamodule=datamodule)
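    # This saves the state at the end of training; the best checkpoints by val_loss
    # are handled by the ModelCheckpoint callback when logging is enabled.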
trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
if __name__ == "__main__":
main()