import argparse
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
from data_preprocessing import LatexImageDataModule
from model import Transformer
from utils import LogImageTexCallback

# TODO: update python, make tex tokens always decodable, ensemble last checkpoints,
#  clear checkpoint data, build full dataset, train, export model to torchscript,
#  write spaces interface


def check_setup():
    """Create and persist the datamodule and tokenizer on the first run."""
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if not os.path.isfile(DATAMODULE_PATH):
        datamodule = LatexImageDataModule(image_width=1024, image_height=128,
                                          batch_size=16, random_magnitude=5)
        torch.save(datamodule, DATAMODULE_PATH)
    if not os.path.isfile(TOKENIZER_PATH):
        datamodule = torch.load(DATAMODULE_PATH)
        datamodule.train_tokenizer()


def parse_args():
    parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("gpus", type=int, default=None, nargs="*",
                        help=f"Ids of gpus in range 0..{torch.cuda.device_count()} to train on; "
                             "if not provided, trains on cpu")
    parser.add_argument("-l", "-log", dest="log", default=False, action="store_true",
                        help="Whether to save logs of the run to the w&b logger, default False")
    parser.add_argument("-m", "-max-epochs", dest="max_epochs", type=int,
                        help="Limit the number of training epochs")
    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
    datamodule = torch.load(DATAMODULE_PATH)
    parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
                        help="Create new datamodule and exit, current parameters:\n"
                             + "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512),
                        ("nhead", 8), ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
                        help="Transformer init args, reference values:\n"
                             + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
    args = parser.parse_args()
    if args.datamodule_args:
        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
    if args.transformer_args:
        # Cast the raw command-line strings to the type of the matching reference value
        args.transformer_args = {key: type(default)(value)
                                 for (key, default), value in zip(transformer_args, args.transformer_args)}
    else:
        args.transformer_args = dict(transformer_args)
    return args


def main():
    check_setup()
    args = parse_args()
    if args.datamodule_args:
        # Rebuild the datamodule and tokenizer with the requested parameters, then exit
        datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
                                          image_height=args.datamodule_args["image_height"],
                                          batch_size=args.datamodule_args["batch_size"],
                                          random_magnitude=args.datamodule_args["random_magnitude"])
        datamodule.train_tokenizer()
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
        torch.save(datamodule, DATAMODULE_PATH)
        return

    datamodule = torch.load(DATAMODULE_PATH)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    logger = None
    callbacks = []
    if args.log:
        logger = WandbLogger("img2tex", log_model=True)
        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
                     LearningRateMonitor(logging_interval="step"),
                     ModelCheckpoint(save_top_k=10, monitor="val_loss", mode="min",
                                     filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
    trainer = Trainer(max_epochs=args.max_epochs,
                      accelerator="cpu" if args.gpus is None else "gpu",
                      gpus=args.gpus,
                      logger=logger,
                      strategy="ddp_find_unused_parameters_false",
                      enable_progress_bar=True,
                      default_root_dir=TRAINER_DIR,
                      callbacks=callbacks,
                      check_val_every_n_epoch=5)
    transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
                              num_decoder_layers=args.transformer_args["num_decoder_layers"],
                              d_model=args.transformer_args["d_model"],
                              nhead=args.transformer_args["nhead"],
                              dim_feedforward=args.transformer_args["dim_feedforward"],
                              dropout=args.transformer_args["dropout"],
                              image_width=datamodule.hparams["image_width"],
                              image_height=datamodule.hparams["image_height"],
                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))
    trainer.fit(transformer, datamodule=datamodule)
    trainer.test(transformer, datamodule=datamodule)


if __name__ == "__main__":
    main()