dkoshman committed
Commit 96feb73
Parent: 29bcc5f

app.py interface, made functions more independent, ensemble, working prototype

Files changed (7)
  1. app.py +25 -13
  2. constants.py +8 -5
  3. data_preprocessing.py +11 -20
  4. generate.py +0 -23
  5. model.py +0 -21
  6. train.py +61 -29
  7. utils.py +31 -11
app.py CHANGED
@@ -1,17 +1,29 @@
- import streamlit as st
-
- st.markdown("### Hello, world!")
- st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
- # ^-- you can show the user text, images, and a limited subset of html -- everything works just like in jupyter
-
- text = st.text_area("TEXT HERE")
- # ^-- display a text field. The variable text holds whatever string is currently in the field
-
- # from transformers import pipeline
- # pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
- # raw_predictions = pipe(text)
- # the huggingface.transformers code you already know -- it can be swapped for anything from fairseq to catboost
-
- # st.markdown(f"{raw_predictions}")
- st.markdown(f"Simon says {text}!")
- # display the model's output, for the user's amusement
+ from constants import RESOURCES
+ from data_preprocessing import RandomizeImageTransform
+ from utils import beam_search_decode
+
+ import streamlit as st
+ import PIL
+ import torch
+ import torchvision.transforms as T
+
+ MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
+
+ # TODO: make faster
+ transformer = torch.load(MODEL_PATH)
+ image_transform = T.Compose((
+     T.ToTensor(),
+     RandomizeImageTransform(width=transformer.hparams['image_width'],
+                             height=transformer.hparams['image_height'],
+                             random_magnitude=0)
+ ))
+
+ st.markdown("### Image to TeX")
+ st.image("resources/frontend/latex_example_1.png")
+ file_png = st.file_uploader("Upload a PNG image", type=[".png"])
+ if file_png is not None:
+     image = PIL.Image.open(file_png)
+     image = image.convert("RGB")
+     tex = beam_search_decode(transformer, image, image_transform=image_transform)
+     st.latex(tex[0])
+     st.text(tex[0])
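Note on the "# TODO: make faster" above: the dominant cost is that torch.load runs on every Streamlit rerun. A minimal sketch of the usual fix, not part of this commit, assuming a Streamlit version that provides st.cache_resource (older releases used st.experimental_singleton):

# Sketch (not in this commit): cache the model so torch.load runs once per
# process instead of on every Streamlit rerun.
import streamlit as st
import torch

@st.cache_resource
def load_transformer(model_path):
    # must be a checkpoint saved with torch.save(model), i.e. the full module
    return torch.load(model_path)

transformer = load_transformer(MODEL_PATH)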
constants.py CHANGED
@@ -1,11 +1,14 @@
  PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
  GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"

- DATA_DIR = "data"
- LATEX_PATH = "resources/latex.json"
- TRAINER_DIR = "resources/trainer"
- TOKENIZER_PATH = "resources/tokenizer.pt"
- DATAMODULE_PATH = "resources/datamodule.pt"
+ DATA_DIR = "local/data"
+ WANDB_DIR = "local/wandb"
+ TRAINER_DIR = "local/trainer"
+
+ RESOURCES = "resources"
+ LATEX_PATH = RESOURCES + "/latex.json"
+ TOKENIZER_PATH = RESOURCES + "/tokenizer.pt"
+ DATAMODULE_PATH = RESOURCES + "/datamodule.pt"

  NUM_DATALOADER_WORKERS = 4
  PERSISTENT_WORKERS = True  # whether to shut down workers at the end of epoch
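Side note: the RESOURCES + "/..." concatenation is fine on Linux; an equivalent pathlib spelling, shown only for comparison and not part of the commit:

# Hypothetical pathlib variant of the same constants (not in the commit).
from pathlib import Path

RESOURCES = Path("resources")
LATEX_PATH = RESOURCES / "latex.json"
TOKENIZER_PATH = RESOURCES / "tokenizer.pt"
DATAMODULE_PATH = RESOURCES / "datamodule.pt"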
data_preprocessing.py CHANGED
@@ -73,26 +73,17 @@ class RandomizeImageTransform(object):
      """Standardize image and randomly augment"""

      def __init__(self, width, height, random_magnitude):
-         if random_magnitude > 0:
-             self.transform = T.Compose((
-                 T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
-                               saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
-                 T.Resize(height),
-                 T.Grayscale(),
-                 T.functional.invert,
-                 T.CenterCrop((height, width)),
-                 torch.Tensor.contiguous,
-                 T.RandAugment(magnitude=random_magnitude),
-                 T.ConvertImageDtype(torch.float32)
-             ))
-         else:
-             self.transform = T.Compose((
-                 T.Resize(height),
-                 T.Grayscale(),
-                 T.functional.invert,
-                 T.CenterCrop((height, width)),
-                 T.ConvertImageDtype(torch.float32)
-             ))
+         self.transform = T.Compose((
+             T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
+                           saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
+             T.Resize(height, max_size=width),
+             T.Grayscale(),
+             T.functional.invert,
+             T.CenterCrop((height, width)),
+             torch.Tensor.contiguous,
+             T.RandAugment(magnitude=random_magnitude),
+             T.ConvertImageDtype(torch.float32)
+         ))

      def __call__(self, image):
          image = self.transform(image)
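With the branch removed, callers control augmentation purely through random_magnitude; app.py passes 0 for inference. A usage sketch (the 1024x128 size matches the train.py defaults; note that RandAugment(magnitude=0) can still apply magnitude-independent ops such as Equalize, so the magnitude-0 pipeline is not strictly deterministic):

# Illustrative preprocessing for inference, mirroring app.py.
import PIL.Image
import torchvision.transforms as T
from data_preprocessing import RandomizeImageTransform

preprocess = T.Compose((
    T.ToTensor(),  # PIL image -> float tensor (c h w) in [0, 1]
    RandomizeImageTransform(width=1024, height=128, random_magnitude=0),
))
image = PIL.Image.open("equation.png").convert("RGB")  # path illustrative
tensor = preprocess(image)  # (1, 128, 1024): grayscale, inverted, center-cropped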
generate.py DELETED
@@ -1,23 +0,0 @@
- from data_generator import generate_data
-
- import argparse
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Clear old dataset and generate new one")
-     parser.add_argument("size", help="size of new dataset", type=int)
-     parser.add_argument("depth", help="max scope depth of generated equation, no less than 1", type=int)
-     parser.add_argument("length", help="length of equation will be in range length/2..length", type=int)
-     parser.add_argument("fraction", help="fraction of tex vocab to sample tokens from, float in range 0..1", type=float)
-     args = parser.parse_args()
-     return args
-
-
- def main():
-     args = parse_args()
-     generate_data(examples_count=args.size, max_depth=args.depth, equation_length=args.length,
-                   distribution_fraction=args.fraction)
-
-
- if __name__ == "__main__":
-     main()
model.py CHANGED
@@ -73,27 +73,6 @@ class TexEmbedding(nn.Module):
      return tex_ids_batch


- class ImageEncoder(nn.Module):
-     """
-     Given an image, returns its vector representation.
-     """
-
-     def __init__(self, image_width, image_height, d_model, num_layers=8):
-         super().__init__()
-         image_embedding = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=.1)
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=d_model,
-             nhead=8,
-             dim_feedforward=2048,
-             batch_first=True
-         )
-         transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-         self.encode = nn.Sequential(image_embedding, transformer_encoder)
-
-     def forward(self, batch):
-         return self.encode(batch)
-
-
  class Transformer(pl.LightningModule):
      def __init__(self,
                   num_encoder_layers: int,
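The deleted ImageEncoder wrapped a standard PyTorch encoder stack, presumably now folded into Transformer itself. For reference, the same construction as a self-contained sketch, with random patch tokens standing in for the project's ImageEmbedding output:

# Sketch of the encoder construction the deleted class used.
import torch
import torch.nn as nn

d_model = 512
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8,
                                           dim_feedforward=2048, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=8)

# 4 images, each already embedded as 64 patch tokens of width d_model
patch_tokens = torch.randn(4, 64, d_model)
memory = encoder(patch_tokens)  # (4, 64, 512), the images' vector representations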
train.py CHANGED
@@ -1,7 +1,8 @@
- from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
+ from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH, WANDB_DIR, RESOURCES
+ from data_generator import generate_data
  from data_preprocessing import LatexImageDataModule
  from model import Transformer
- from utils import LogImageTexCallback
+ from utils import LogImageTexCallback, average_checkpoints

  import argparse
  import os
@@ -11,44 +12,60 @@ from pytorch_lightning import Trainer
  import torch


- # TODO: update python, make tex tokens always decodable, ensemble last checkpoints,
- # clear checkpoint data, build full dataset, train, export model to torchscript, write spaces interface
-
  def check_setup():
+     print(
+         "Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out")
      os.environ["TOKENIZERS_PARALLELISM"] = "false"
      if not os.path.isfile(DATAMODULE_PATH):
+         print("Generating default datamodule")
          datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
          torch.save(datamodule, DATAMODULE_PATH)
      if not os.path.isfile(TOKENIZER_PATH):
+         print("Generating default tokenizer")
          datamodule = torch.load(DATAMODULE_PATH)
          datamodule.train_tokenizer()


  def parse_args():
-     parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
-
-     parser.add_argument("gpus", type=int, default=None,
-                         help=f"Ids of gpus in range 0..{torch.cuda.device_count()} to train on, "
-                              "if not provided, then trains on cpu", nargs="*")
-     parser.add_argument("-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
-                         action="store_true", dest="log")
-     parser.add_argument("-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
-
-     datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
+     parser = argparse.ArgumentParser(description="Workflow: generate dataset, create datamodule, train model",
+                                      allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
+
+     parser.add_argument(
+         "gpus", type=int, help=f"Ids of gpus in range 0..{torch.cuda.device_count() - 1} to train on, "
+                                "if not provided,\nthen trains on cpu. To see current gpu load, run nvtop", nargs="*")
+     parser.add_argument(
+         "-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
+         action="store_true", dest="log")
+     parser.add_argument(
+         "-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
+
+     data_args = ["size", "depth", "length", "fraction"]
+     parser.add_argument(
+         "-n", metavar=tuple(map(str.upper, data_args)), nargs=4, dest="data_args",
+         type=lambda x: int(x) if x.isdigit() else float(x),
+         help="Clear old dataset, create new and exit, args:"
+              "\nsize\tsize of new dataset"
+              "\ndepth\tmax scope depth of generated equation, no less than 1"
+              "\nlength\tlength of equation will be in range length/2..length"
+              "\nfraction\tfraction of tex vocab to sample tokens from, float in range 0..1")

      datamodule = torch.load(DATAMODULE_PATH)
-     parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
-                         help="Create new datamodule and exit, current parameters:\n" +
-                              "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
+     datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
+     parser.add_argument(
+         "-d", metavar=tuple(map(str.upper, datamodule_args)), nargs=4, dest="datamodule_args", type=int,
+         help="Create new datamodule and exit, current parameters:\n" +
+              "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))

      transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                          ("dim_feedforward", 2048), ("dropout", 0.1)]
-     parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
-                         help="Transformer init args, reference values:\n" +
-                              "\n".join(f"{k}\t{v}" for k, v in transformer_args))
+     parser.add_argument(
+         "-t", metavar=tuple(args[0].upper() for args in transformer_args), dest="transformer_args",
+         nargs=len(transformer_args),
+         help="Transformer init args, default values:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))

      args = parser.parse_args()
-
+     if args.data_args:
+         args.data_args = dict(zip(data_args, args.data_args))
      if args.datamodule_args:
          args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))

@@ -63,6 +80,13 @@ def parse_args():
  def main():
      check_setup()
      args = parse_args()
+     if args.data_args:
+         generate_data(examples_count=args.data_args['size'],
+                       max_depth=args.data_args['depth'],
+                       equation_length=args.data_args['length'],
+                       distribution_fraction=args.data_args['fraction'])
+         return
+
      if args.datamodule_args:
          datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
                                            image_height=args.datamodule_args["image_height"],
@@ -79,23 +103,23 @@ def main():
      logger = None
      callbacks = []
      if args.log:
-         logger = WandbLogger(f"img2tex", log_model=True)
-         callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
+         logger = WandbLogger(f"img2tex", save_dir=WANDB_DIR, log_model=True)
+         callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
                       LearningRateMonitor(logging_interval="step"),
                       ModelCheckpoint(save_top_k=10,
+                                      every_n_train_steps=500,
                                       monitor="val_loss",
                                       mode="min",
                                       filename="img2tex-{epoch:02d}-{val_loss:.2f}")]

-     trainer = Trainer(max_epochs=args.max_epochs,
-                       accelerator="cpu" if args.gpus is None else "gpu",
+     trainer = Trainer(default_root_dir=TRAINER_DIR,
+                       max_epochs=args.max_epochs,
+                       accelerator="gpu" if args.gpus else "cpu",
                        gpus=args.gpus,
                        logger=logger,
                        strategy="ddp_find_unused_parameters_false",
                        enable_progress_bar=True,
-                       default_root_dir=TRAINER_DIR,
-                       callbacks=callbacks,
-                       check_val_every_n_epoch=5)
+                       callbacks=callbacks)

      transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
                                num_decoder_layers=args.transformer_args["num_decoder_layers"],
@@ -111,6 +135,14 @@ def main():
      trainer.fit(transformer, datamodule=datamodule)
      trainer.test(transformer, datamodule=datamodule)

+     if args.log:
+         transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
+         transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
+         transformer.eval()
+         transformer.freeze()
+         torch.save(transformer.state_dict(), transformer_path)
+         print(f"Transformer ensemble saved to '{transformer_path}'")
+

  if __name__ == "__main__":
      main()
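The reworked parse_args makes train.py the single entry point for the whole workflow (absorbing the deleted generate.py). Illustrative invocations, with example values rather than defaults; note that -t is effectively required for a training run, since main() reads args.transformer_args unconditionally:

$ # Regenerate the dataset (size, depth, length, fraction), then exit:
$ python train.py -n 100000 3 60 0.5
$ # Rebuild the datamodule (image_width, image_height, batch_size, random_magnitude), then exit:
$ python train.py -d 1024 128 16 5
$ # Train on gpus 0 and 1 with w&b logging, at most 100 epochs:
$ python train.py 0 1 -l -m 100 -t 6 6 512 8 2048 0.1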
utils.py CHANGED
@@ -1,7 +1,7 @@
  from constants import TOKENIZER_PATH
- from data_preprocessing import RandomizeImageTransform

  import einops
+ import os
  import random
  from pytorch_lightning.callbacks import Callback
  import torch
@@ -22,8 +22,7 @@ class LogImageTexCallback(Callback):
          return
      sample_id = random.randint(0, len(batch['images']) - 1)
      image = batch['images'][sample_id]
-     texs_predicted, texs_ids = beam_search_decode(transformer, image, transform_image=False, top_k=self.top_k,
-                                                   max_length=self.max_length)
+     texs_predicted = beam_search_decode(transformer, image, top_k=self.top_k, max_length=self.max_length)
      image = self.tensor_to_PIL(image)
      tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)))
      self.logger.log_image(key="samples", images=[image],
@@ -31,9 +30,8 @@ class LogImageTexCallback(Callback):


  @torch.inference_mode()
- def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
+ def beam_search_decode(transformer, image, image_transform=None, top_k=10, max_length=100):
      """Performs decoding maintaining k best candidates"""
-     assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"

      def get_tgt_padding_mask(tgt):
          mask = tgt == tex_tokenizer.token_to_id("[SEP]")
@@ -41,12 +39,11 @@ def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
          mask = mask.to(transformer.device, torch.bool)
          return mask

+     if image_transform:
+         image = image_transform(image)
+
+     assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"
      src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
-     if transform_image:
-         image_transform = RandomizeImageTransform(width=transformer.hparams["image_width"],
-                                                   height=transformer.hparams["image_width"],
-                                                   random_magnitude=0)
-         src = image_transform(src)
      memory = transformer.encode(src)

      tex_tokenizer = torch.load(TOKENIZER_PATH)
@@ -82,4 +79,27 @@ def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
                      padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
                      tex_tokenizer.token_to_id("[PAD]")).tolist()
      texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
-     return texs, candidates_tex_ids
+     texs = [tex.replace("\\ ", "\\") for tex in texs]
+     return texs
+
+
+ def average_checkpoints(model_type, checkpoints_dir):
+     """Returns a model whose weights are the average of all checkpoints
+     Args:
+         :model_type: -- pytorch_lightning.LightningModule class corresponding to the checkpoints
+         :checkpoints_dir: -- path to the checkpoints directory
+     """
+     checkpoints = [checkpoint.path for checkpoint in os.scandir(checkpoints_dir)]
+     n_models = len(checkpoints)
+     assert n_models > 0
+     average_model = model_type.load_from_checkpoint(checkpoints[0])
+
+     for checkpoint in checkpoints[1:]:
+         model = model_type.load_from_checkpoint(checkpoint)
+         for weight, weight_to_add in zip(average_model.parameters(), model.parameters()):
+             weight.data.add_(weight_to_add.data)
+
+     for weight in average_model.parameters():
+         weight.data.divide_(n_models)
+
+     return average_model
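average_checkpoints implements the "ensemble" from the commit message as plain weight averaging: for checkpoints with parameters theta_1..theta_n it returns theta_bar = (1/n) * sum_i theta_i, which is only meaningful when all checkpoints share one architecture and come from the same run. A usage sketch (directory and file paths illustrative):

# Illustrative use of average_checkpoints after a logged training run.
import torch
from model import Transformer
from utils import average_checkpoints

# Average the top-k checkpoints that ModelCheckpoint kept for one run.
model = average_checkpoints(model_type=Transformer,
                            checkpoints_dir="local/trainer/checkpoints")
model.eval()
model.freeze()
# app.py restores the model with a bare torch.load, which requires pickling
# the full module here rather than only its state_dict.
torch.save(model, "resources/model_averaged.pt")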