# NOTE(review): the lines originally here were non-Python web-scrape residue
# (a Hugging Face Spaces file-listing header: "Spaces:", "Runtime error",
# blame commit hashes, and a line-number gutter). Removed so the file parses.
from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
from data_preprocessing import LatexImageDataModule
from model import Transformer
from utils import LogImageTexCallback
import argparse
import os
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch
# TODO: update Python version; make tex tokens always decodable; ensemble the
# last checkpoints; clear stale checkpoint data; build the full dataset; train;
# export the model to TorchScript; write the Spaces interface.
def check_setup():
    """Ensure the serialized datamodule and tokenizer artifacts exist.

    Creates and saves a default ``LatexImageDataModule`` when none has been
    persisted yet, then trains (and thereby persists) the tokenizer when the
    tokenizer file is missing.
    """
    # Silence tokenizers' fork warnings when used alongside DataLoader workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    datamodule = None
    if not os.path.isfile(DATAMODULE_PATH):
        datamodule = LatexImageDataModule(image_width=1024, image_height=128,
                                          batch_size=16, random_magnitude=5)
        torch.save(datamodule, DATAMODULE_PATH)
    if not os.path.isfile(TOKENIZER_PATH):
        # Reuse the freshly built datamodule instead of re-loading it from disk.
        if datamodule is None:
            datamodule = torch.load(DATAMODULE_PATH)
        datamodule.train_tokenizer()
def parse_args():
    """Parse command-line arguments for training.

    Returns the parsed namespace with:
      - ``gpus``: list of GPU ids, or None to train on CPU;
      - ``log``: whether to log the run to W&B;
      - ``max_epochs``: optional epoch limit;
      - ``datamodule_args``: dict of datamodule hyperparameters (or None);
      - ``transformer_args``: dict of Transformer init kwargs, with values
        cast to their proper numeric types.
    """
    parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("gpus", type=int, default=None,
                        help=f"Ids of gpus in range 0..{torch.cuda.device_count()} to train on, "
                        "if not provided, then trains on cpu", nargs="*")
    parser.add_argument("-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
                        action="store_true", dest="log")
    parser.add_argument("-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
    datamodule = torch.load(DATAMODULE_PATH)
    parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
                        help="Create new datamodule and exit, current parameters:\n" +
                        "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
    # (name, reference default) pairs; the default's type doubles as the parser
    # for values supplied via -t (int for sizes/heads, float for dropout).
    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                        ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
                        help="Transformer init args, reference values:\n" +
                        "\n".join(f"{k}\t{v}" for k, v in transformer_args))
    args = parser.parse_args()
    if args.datamodule_args:
        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
    if args.transformer_args:
        # BUG FIX: argparse yields strings for -t (no type= was given); cast each
        # value to the type of its reference default so Transformer gets numbers.
        args.transformer_args = {name: type(default)(value)
                                 for (name, default), value in zip(transformer_args, args.transformer_args)}
    else:
        args.transformer_args = dict(transformer_args)
    return args
def main():
    """Entry point: prepare artifacts, then either rebuild the datamodule
    (``-d`` mode) or train and test the Transformer."""
    check_setup()
    args = parse_args()
    if args.datamodule_args:
        # -d mode: build a new datamodule + tokenizer, persist both, and exit.
        datamodule = LatexImageDataModule(**args.datamodule_args)
        datamodule.train_tokenizer()
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
        torch.save(datamodule, DATAMODULE_PATH)
        return
    datamodule = torch.load(DATAMODULE_PATH)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    logger = None
    callbacks = []
    if args.log:
        # Fixed F541: the run name was an f-string with no placeholders.
        logger = WandbLogger("img2tex", log_model=True)
        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
                     LearningRateMonitor(logging_interval="step"),
                     ModelCheckpoint(save_top_k=10,
                                     monitor="val_loss",
                                     mode="min",
                                     filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
    trainer = Trainer(max_epochs=args.max_epochs,
                      accelerator="cpu" if args.gpus is None else "gpu",
                      gpus=args.gpus,
                      logger=logger,
                      strategy="ddp_find_unused_parameters_false",
                      enable_progress_bar=True,
                      default_root_dir=TRAINER_DIR,
                      callbacks=callbacks,
                      check_val_every_n_epoch=5)
    # transformer_args already holds all six architecture kwargs keyed by name.
    transformer = Transformer(**args.transformer_args,
                              image_width=datamodule.hparams["image_width"],
                              image_height=datamodule.hparams["image_height"],
                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))
    trainer.fit(transformer, datamodule=datamodule)
    trainer.test(transformer, datamodule=datamodule)
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
|