import argparse
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH, WANDB_DIR, RESOURCES
from data_generator import generate_data
from data_preprocessing import LatexImageDataModule
from model import Transformer
from utils import LogImageTexCallback, average_checkpoints
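
# Illustrative invocations, assuming this file is the script entry point (the
# file name "main.py" is a guess; flag semantics follow parse_args below):
#   python main.py                      # train on CPU with the saved datamodule/tokenizer
#   python main.py 0 1 -l -m 10        # train on GPUs 0 and 1, log to W&B, cap at 10 epochs
#   python main.py -n 10000 3 20 0.5   # regenerate the dataset, then exit
#   python main.py -d 1024 128 16 5    # rebuild the datamodule and tokenizer, then exit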


def check_setup():
    """Create the default datamodule and tokenizer on first run."""
    # Tokenizers parallelism is unsafe once the process forks (e.g. in DataLoader
    # workers), so disable it up front.
    print("Disabling tokenizers parallelism because it can't be used safely across process forks")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if not os.path.isfile(DATAMODULE_PATH):
        print("Generating default datamodule")
        datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
        torch.save(datamodule, DATAMODULE_PATH)
    if not os.path.isfile(TOKENIZER_PATH):
        print("Generating default tokenizer")
        datamodule = torch.load(DATAMODULE_PATH)
        datamodule.train_tokenizer()


def parse_args():
    parser = argparse.ArgumentParser(description="Workflow: generate dataset, create datamodule, train model",
                                     allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "gpus", type=int, nargs="*",
        help=f"IDs of GPUs in range 0..{torch.cuda.device_count() - 1} to train on; "
             "if none are given,\ntrains on the CPU. To see current GPU load, run nvtop")
    parser.add_argument(
        "-l", "-log", help="Log the run to the W&B logger (default: False)", default=False,
        action="store_true", dest="log")
    parser.add_argument(
        "-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
    data_args = ["size", "depth", "length", "fraction"]
    parser.add_argument(
        "-n", metavar=tuple(map(str.upper, data_args)), nargs=4, dest="data_args",
        type=lambda x: int(x) if x.isdigit() else float(x),
        help="Delete the old dataset, generate a new one, and exit. Args:"
             "\nsize\tsize of the new dataset"
             "\ndepth\tmax scope nesting depth of a generated equation, at least 1"
             "\nlength\tequation length will be in range length/2..length"
             "\nfraction\tfraction of the TeX vocab to sample tokens from, float in range 0..1")
    # The saved datamodule is loaded here only so its current hyperparameters
    # can be shown in the help text below.
    datamodule = torch.load(DATAMODULE_PATH)
    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
    parser.add_argument(
        "-d", metavar=tuple(map(str.upper, datamodule_args)), nargs=4, dest="datamodule_args", type=int,
        help="Create a new datamodule and exit, current parameters:\n" +
             "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                        ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument(
        "-t", metavar=tuple(args[0].upper() for args in transformer_args), dest="transformer_args",
        nargs=len(transformer_args),
        # Without an explicit type, argparse would leave these values as strings.
        type=lambda x: int(x) if x.isdigit() else float(x),
        help="Transformer init args, default values:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
    args = parser.parse_args()
    # Convert the flat CLI value lists into keyword dictionaries.
    if args.data_args:
        args.data_args = dict(zip(data_args, args.data_args))
    if args.datamodule_args:
        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
    if args.transformer_args:
        args.transformer_args = dict(zip(list(zip(*transformer_args))[0], args.transformer_args))
    else:
        args.transformer_args = dict(transformer_args)
    return args
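
# For illustration (script name assumed), "python main.py 0 -t 6 6 512 8 2048 0.1"
# parses to roughly:
#   Namespace(gpus=[0], log=False, max_epochs=None, data_args=None, datamodule_args=None,
#             transformer_args={'num_encoder_layers': 6, ..., 'dropout': 0.1})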


def main():
    check_setup()
    args = parse_args()
    if args.data_args:
        generate_data(examples_count=args.data_args["size"],
                      max_depth=args.data_args["depth"],
                      equation_length=args.data_args["length"],
                      distribution_fraction=args.data_args["fraction"])
        return
    if args.datamodule_args:
        datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
                                          image_height=args.datamodule_args["image_height"],
                                          batch_size=args.datamodule_args["batch_size"],
                                          random_magnitude=args.datamodule_args["random_magnitude"])
        datamodule.train_tokenizer()
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
        torch.save(datamodule, DATAMODULE_PATH)
        return
    datamodule = torch.load(DATAMODULE_PATH)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    logger = None
    callbacks = []
    if args.log:
        logger = WandbLogger("img2tex", save_dir=WANDB_DIR, log_model=True)
        # Keep the 10 best checkpoints by validation loss, evaluated every 500 steps.
        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
                     LearningRateMonitor(logging_interval="step"),
                     ModelCheckpoint(save_top_k=10,
                                     every_n_train_steps=500,
                                     monitor="val_loss",
                                     mode="min",
                                     filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
    trainer = Trainer(default_root_dir=TRAINER_DIR,
                      max_epochs=args.max_epochs,
                      accelerator="gpu" if args.gpus else "cpu",
                      gpus=args.gpus,
                      logger=logger,
                      # DDP only applies to GPU runs; on CPU, fall back to the
                      # default strategy instead of requesting DDP.
                      strategy="ddp_find_unused_parameters_false" if args.gpus else None,
                      enable_progress_bar=True,
                      callbacks=callbacks)
    transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
                              num_decoder_layers=args.transformer_args["num_decoder_layers"],
                              d_model=args.transformer_args["d_model"],
                              nhead=args.transformer_args["nhead"],
                              dim_feedforward=args.transformer_args["dim_feedforward"],
                              dropout=args.transformer_args["dropout"],
                              image_width=datamodule.hparams["image_width"],
                              image_height=datamodule.hparams["image_height"],
                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))
    trainer.fit(transformer, datamodule=datamodule)
    trainer.test(transformer, datamodule=datamodule)
    if args.log:
        # Average the saved checkpoints into a single set of weights ("ensemble")
        # and store it under the run's W&B version id.
        transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
        transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
        transformer.eval()
        transformer.freeze()
        torch.save(transformer.state_dict(), transformer_path)
        print(f"Transformer ensemble saved to '{transformer_path}'")


if __name__ == "__main__":
    main()
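
# Loading the averaged weights back for inference might look like the sketch
# below; the constructor arguments must match the ones used for training:
#   transformer = Transformer(num_encoder_layers=6, ..., pad_idx=pad_idx)
#   transformer.load_state_dict(torch.load(transformer_path))
#   transformer.eval()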