Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
•
96feb73
1
Parent(s):
29bcc5f
app.py interface, made functions more independent, ensemble, working prototype
Browse files- app.py +25 -13
- constants.py +8 -5
- data_preprocessing.py +11 -20
- generate.py +0 -23
- model.py +0 -21
- train.py +61 -29
- utils.py +31 -11
app.py
CHANGED
@@ -1,17 +1,29 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
6 |
|
7 |
-
|
8 |
-
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
st.
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import RESOURCES
|
2 |
+
from data_preprocessing import RandomizeImageTransform
|
3 |
+
from utils import beam_search_decode
|
4 |
|
5 |
+
import streamlit as st
|
6 |
+
import PIL
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as T
|
9 |
|
10 |
+
MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
|
|
|
11 |
|
12 |
+
# TODO: make faster
|
13 |
+
transformer = torch.load(MODEL_PATH)
|
14 |
+
image_transform = T.Compose((
|
15 |
+
T.ToTensor(),
|
16 |
+
RandomizeImageTransform(width=transformer.hparams['image_width'],
|
17 |
+
height=transformer.hparams['image_height'],
|
18 |
+
random_magnitude=0)
|
19 |
+
))
|
20 |
|
21 |
+
st.markdown("### Image to TeX")
|
22 |
+
st.image("resources/frontend/latex_example_1.png")
|
23 |
+
file_png = st.file_uploader("Upload a PNG image", type=([".png"]))
|
24 |
+
if file_png is not None:
|
25 |
+
image = PIL.Image.open(file_png)
|
26 |
+
image = image.convert("RGB")
|
27 |
+
tex = beam_search_decode(transformer, image, image_transform=image_transform)
|
28 |
+
st.latex(tex[0])
|
29 |
+
st.text(tex[0])
|
constants.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
|
2 |
GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
3 |
|
4 |
-
DATA_DIR = "data"
|
5 |
-
|
6 |
-
TRAINER_DIR = "
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
NUM_DATALOADER_WORKERS = 4
|
11 |
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
|
|
1 |
PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
|
2 |
GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
3 |
|
4 |
+
DATA_DIR = "local/data"
|
5 |
+
WANDB_DIR = "local/wandb"
|
6 |
+
TRAINER_DIR = "local/trainer"
|
7 |
+
|
8 |
+
RESOURCES = "resources"
|
9 |
+
LATEX_PATH = RESOURCES + "/latex.json"
|
10 |
+
TOKENIZER_PATH = RESOURCES + "/tokenizer.pt"
|
11 |
+
DATAMODULE_PATH = RESOURCES + "/datamodule.pt"
|
12 |
|
13 |
NUM_DATALOADER_WORKERS = 4
|
14 |
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
data_preprocessing.py
CHANGED
@@ -73,26 +73,17 @@ class RandomizeImageTransform(object):
|
|
73 |
"""Standardize image and randomly augment"""
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
))
|
88 |
-
else:
|
89 |
-
self.transform = T.Compose((
|
90 |
-
T.Resize(height),
|
91 |
-
T.Grayscale(),
|
92 |
-
T.functional.invert,
|
93 |
-
T.CenterCrop((height, width)),
|
94 |
-
T.ConvertImageDtype(torch.float32)
|
95 |
-
))
|
96 |
|
97 |
def __call__(self, image):
|
98 |
image = self.transform(image)
|
|
|
73 |
"""Standardize image and randomly augment"""
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
|
76 |
+
self.transform = T.Compose((
|
77 |
+
T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
|
78 |
+
saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
|
79 |
+
T.Resize(height, max_size=width),
|
80 |
+
T.Grayscale(),
|
81 |
+
T.functional.invert,
|
82 |
+
T.CenterCrop((height, width)),
|
83 |
+
torch.Tensor.contiguous,
|
84 |
+
T.RandAugment(magnitude=random_magnitude),
|
85 |
+
T.ConvertImageDtype(torch.float32)
|
86 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
def __call__(self, image):
|
89 |
image = self.transform(image)
|
generate.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
from data_generator import generate_data
|
2 |
-
|
3 |
-
import argparse
|
4 |
-
|
5 |
-
|
6 |
-
def parse_args():
|
7 |
-
parser = argparse.ArgumentParser(description="Clear old dataset and generate new one")
|
8 |
-
parser.add_argument("size", help="size of new dataset", type=int)
|
9 |
-
parser.add_argument("depth", help="max_depth scope depth of generated equation, no less than 1", type=int)
|
10 |
-
parser.add_argument("length", help="length of equation will be in range length/2..length", type=int)
|
11 |
-
parser.add_argument("fraction", help="fraction of tex vocab to sample tokens from, float in range 0..1", type=float)
|
12 |
-
args = parser.parse_args()
|
13 |
-
return args
|
14 |
-
|
15 |
-
|
16 |
-
def main():
|
17 |
-
args = parse_args()
|
18 |
-
generate_data(examples_count=args.size, max_depth=args.depth, equation_length=args.length,
|
19 |
-
distribution_fraction=args.fraction)
|
20 |
-
|
21 |
-
|
22 |
-
if __name__ == "__main__":
|
23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
CHANGED
@@ -73,27 +73,6 @@ class TexEmbedding(nn.Module):
|
|
73 |
return tex_ids_batch
|
74 |
|
75 |
|
76 |
-
class ImageEncoder(nn.Module):
|
77 |
-
"""
|
78 |
-
Given an image, returns its vector representation.
|
79 |
-
"""
|
80 |
-
|
81 |
-
def __init__(self, image_width, image_height, d_model, num_layers=8):
|
82 |
-
super().__init__()
|
83 |
-
image_embedding = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=.1)
|
84 |
-
encoder_layer = nn.TransformerEncoderLayer(
|
85 |
-
d_model=d_model,
|
86 |
-
nhead=8,
|
87 |
-
dim_feedforward=2048,
|
88 |
-
batch_first=True
|
89 |
-
)
|
90 |
-
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
91 |
-
self.encode = nn.Sequential(image_embedding, transformer_encoder)
|
92 |
-
|
93 |
-
def forward(self, batch):
|
94 |
-
return self.encode(batch)
|
95 |
-
|
96 |
-
|
97 |
class Transformer(pl.LightningModule):
|
98 |
def __init__(self,
|
99 |
num_encoder_layers: int,
|
|
|
73 |
return tex_ids_batch
|
74 |
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
class Transformer(pl.LightningModule):
|
77 |
def __init__(self,
|
78 |
num_encoder_layers: int,
|
train.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
|
|
|
2 |
from data_preprocessing import LatexImageDataModule
|
3 |
from model import Transformer
|
4 |
-
from utils import LogImageTexCallback
|
5 |
|
6 |
import argparse
|
7 |
import os
|
@@ -11,44 +12,60 @@ from pytorch_lightning import Trainer
|
|
11 |
import torch
|
12 |
|
13 |
|
14 |
-
# TODO: update python, make tex tokens always decodable, ensemble last checkpoints,
|
15 |
-
# clear checkpoint data build full dataset, train export model to torchscript write spaces interface
|
16 |
-
|
17 |
def check_setup():
|
|
|
|
|
18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
19 |
if not os.path.isfile(DATAMODULE_PATH):
|
|
|
20 |
datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
|
21 |
torch.save(datamodule, DATAMODULE_PATH)
|
22 |
if not os.path.isfile(TOKENIZER_PATH):
|
|
|
23 |
datamodule = torch.load(DATAMODULE_PATH)
|
24 |
datamodule.train_tokenizer()
|
25 |
|
26 |
|
27 |
def parse_args():
|
28 |
-
parser = argparse.ArgumentParser(
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
datamodule = torch.load(DATAMODULE_PATH)
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
|
44 |
transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
|
45 |
("dim_feedforward", 2048), ("dropout", 0.1)]
|
46 |
-
parser.add_argument(
|
47 |
-
|
48 |
-
|
|
|
49 |
|
50 |
args = parser.parse_args()
|
51 |
-
|
|
|
52 |
if args.datamodule_args:
|
53 |
args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
|
54 |
|
@@ -63,6 +80,13 @@ def parse_args():
|
|
63 |
def main():
|
64 |
check_setup()
|
65 |
args = parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
if args.datamodule_args:
|
67 |
datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
|
68 |
image_height=args.datamodule_args["image_height"],
|
@@ -79,23 +103,23 @@ def main():
|
|
79 |
logger = None
|
80 |
callbacks = []
|
81 |
if args.log:
|
82 |
-
logger = WandbLogger(f"img2tex", log_model=True)
|
83 |
-
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=
|
84 |
LearningRateMonitor(logging_interval="step"),
|
85 |
ModelCheckpoint(save_top_k=10,
|
|
|
86 |
monitor="val_loss",
|
87 |
mode="min",
|
88 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
89 |
|
90 |
-
trainer = Trainer(
|
91 |
-
|
|
|
92 |
gpus=args.gpus,
|
93 |
logger=logger,
|
94 |
strategy="ddp_find_unused_parameters_false",
|
95 |
enable_progress_bar=True,
|
96 |
-
|
97 |
-
callbacks=callbacks,
|
98 |
-
check_val_every_n_epoch=5)
|
99 |
|
100 |
transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
|
101 |
num_decoder_layers=args.transformer_args["num_decoder_layers"],
|
@@ -111,6 +135,14 @@ def main():
|
|
111 |
trainer.fit(transformer, datamodule=datamodule)
|
112 |
trainer.test(transformer, datamodule=datamodule)
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
if __name__ == "__main__":
|
116 |
main()
|
|
|
1 |
+
from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH, WANDB_DIR, RESOURCES
|
2 |
+
from data_generator import generate_data
|
3 |
from data_preprocessing import LatexImageDataModule
|
4 |
from model import Transformer
|
5 |
+
from utils import LogImageTexCallback, average_checkpoints
|
6 |
|
7 |
import argparse
|
8 |
import os
|
|
|
12 |
import torch
|
13 |
|
14 |
|
|
|
|
|
|
|
15 |
def check_setup():
|
16 |
+
print(
|
17 |
+
"Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out")
|
18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
19 |
if not os.path.isfile(DATAMODULE_PATH):
|
20 |
+
print("Generating default datamodule")
|
21 |
datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
|
22 |
torch.save(datamodule, DATAMODULE_PATH)
|
23 |
if not os.path.isfile(TOKENIZER_PATH):
|
24 |
+
print("Generating default tokenizer")
|
25 |
datamodule = torch.load(DATAMODULE_PATH)
|
26 |
datamodule.train_tokenizer()
|
27 |
|
28 |
|
29 |
def parse_args():
|
30 |
+
parser = argparse.ArgumentParser(description="Workflow: generate dataset, create datamodule, train model",
|
31 |
+
allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
|
32 |
+
|
33 |
+
parser.add_argument(
|
34 |
+
"gpus", type=int, help=f"Ids of gpus in range 0..{torch.cuda.device_count() - 1} to train on, "
|
35 |
+
"if not provided,\nthen trains on cpu. To see current gpu load, run nvtop", nargs="*")
|
36 |
+
parser.add_argument(
|
37 |
+
"-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
|
38 |
+
action="store_true", dest="log")
|
39 |
+
parser.add_argument(
|
40 |
+
"-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
|
41 |
+
|
42 |
+
data_args = ["size", "depth", "length", "fraction"]
|
43 |
+
parser.add_argument(
|
44 |
+
"-n", metavar=tuple(map(str.upper, data_args)), nargs=4, dest="data_args",
|
45 |
+
type=lambda x: int(x) if x.isdigit() else float(x),
|
46 |
+
help="Clear old dataset, create new and exit, args:"
|
47 |
+
"\nsize\tsize of new dataset"
|
48 |
+
"\ndepth\tmax_depth scope depth of generated equation, no less than 1"
|
49 |
+
"\nlength\tlength of equation will be in range length/2..length"
|
50 |
+
"\nfraction\tfraction of tex vocab to sample tokens from, float in range 0..1")
|
51 |
|
52 |
datamodule = torch.load(DATAMODULE_PATH)
|
53 |
+
datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
|
54 |
+
parser.add_argument(
|
55 |
+
"-d", metavar=tuple(map(str.upper, datamodule_args)), nargs=4, dest="datamodule_args", type=int,
|
56 |
+
help="Create new datamodule and exit, current parameters:\n" +
|
57 |
+
"\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
|
58 |
|
59 |
transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
|
60 |
("dim_feedforward", 2048), ("dropout", 0.1)]
|
61 |
+
parser.add_argument(
|
62 |
+
"-t", metavar=tuple(args[0].upper() for args in transformer_args), dest="transformer_args",
|
63 |
+
nargs=len(transformer_args),
|
64 |
+
help="Transformer init args, default values:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
|
65 |
|
66 |
args = parser.parse_args()
|
67 |
+
if args.data_args:
|
68 |
+
args.data_args = dict(zip(data_args, args.data_args))
|
69 |
if args.datamodule_args:
|
70 |
args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
|
71 |
|
|
|
80 |
def main():
|
81 |
check_setup()
|
82 |
args = parse_args()
|
83 |
+
if args.data_args:
|
84 |
+
generate_data(examples_count=args.data_args['size'],
|
85 |
+
max_depth=args.data_args['depth'],
|
86 |
+
equation_length=args.data_args['length'],
|
87 |
+
distribution_fraction=args.data_args['fraction'])
|
88 |
+
return
|
89 |
+
|
90 |
if args.datamodule_args:
|
91 |
datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
|
92 |
image_height=args.datamodule_args["image_height"],
|
|
|
103 |
logger = None
|
104 |
callbacks = []
|
105 |
if args.log:
|
106 |
+
logger = WandbLogger(f"img2tex", save_dir=WANDB_DIR, log_model=True)
|
107 |
+
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
|
108 |
LearningRateMonitor(logging_interval="step"),
|
109 |
ModelCheckpoint(save_top_k=10,
|
110 |
+
every_n_train_steps=500,
|
111 |
monitor="val_loss",
|
112 |
mode="min",
|
113 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
114 |
|
115 |
+
trainer = Trainer(default_root_dir=TRAINER_DIR,
|
116 |
+
max_epochs=args.max_epochs,
|
117 |
+
accelerator="gpu" if args.gpus else "cpu",
|
118 |
gpus=args.gpus,
|
119 |
logger=logger,
|
120 |
strategy="ddp_find_unused_parameters_false",
|
121 |
enable_progress_bar=True,
|
122 |
+
callbacks=callbacks)
|
|
|
|
|
123 |
|
124 |
transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
|
125 |
num_decoder_layers=args.transformer_args["num_decoder_layers"],
|
|
|
135 |
trainer.fit(transformer, datamodule=datamodule)
|
136 |
trainer.test(transformer, datamodule=datamodule)
|
137 |
|
138 |
+
if args.log:
|
139 |
+
transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
|
140 |
+
transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
|
141 |
+
transformer.eval()
|
142 |
+
transformer.freeze()
|
143 |
+
torch.save(transformer.state_dict(), transformer_path)
|
144 |
+
print(f"Transformer ensemble saved to '{transformer_path}'")
|
145 |
+
|
146 |
|
147 |
if __name__ == "__main__":
|
148 |
main()
|
utils.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from constants import TOKENIZER_PATH
|
2 |
-
from data_preprocessing import RandomizeImageTransform
|
3 |
|
4 |
import einops
|
|
|
5 |
import random
|
6 |
from pytorch_lightning.callbacks import Callback
|
7 |
import torch
|
@@ -22,8 +22,7 @@ class LogImageTexCallback(Callback):
|
|
22 |
return
|
23 |
sample_id = random.randint(0, len(batch['images']) - 1)
|
24 |
image = batch['images'][sample_id]
|
25 |
-
texs_predicted
|
26 |
-
max_length=self.max_length)
|
27 |
image = self.tensor_to_PIL(image)
|
28 |
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)))
|
29 |
self.logger.log_image(key="samples", images=[image],
|
@@ -31,9 +30,8 @@ class LogImageTexCallback(Callback):
|
|
31 |
|
32 |
|
33 |
@torch.inference_mode()
|
34 |
-
def beam_search_decode(transformer, image,
|
35 |
"""Performs decoding maintaining k best candidates"""
|
36 |
-
assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"
|
37 |
|
38 |
def get_tgt_padding_mask(tgt):
|
39 |
mask = tgt == tex_tokenizer.token_to_id("[SEP]")
|
@@ -41,12 +39,11 @@ def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_l
|
|
41 |
mask = mask.to(transformer.device, torch.bool)
|
42 |
return mask
|
43 |
|
|
|
|
|
|
|
|
|
44 |
src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
|
45 |
-
if transform_image:
|
46 |
-
image_transform = RandomizeImageTransform(width=transformer.hparams["image_width"],
|
47 |
-
height=transformer.hparams["image_width"],
|
48 |
-
random_magnitude=0)
|
49 |
-
src = image_transform(src)
|
50 |
memory = transformer.encode(src)
|
51 |
|
52 |
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
@@ -82,4 +79,27 @@ def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_l
|
|
82 |
padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
|
83 |
tex_tokenizer.token_to_id("[PAD]")).tolist()
|
84 |
texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from constants import TOKENIZER_PATH
|
|
|
2 |
|
3 |
import einops
|
4 |
+
import os
|
5 |
import random
|
6 |
from pytorch_lightning.callbacks import Callback
|
7 |
import torch
|
|
|
22 |
return
|
23 |
sample_id = random.randint(0, len(batch['images']) - 1)
|
24 |
image = batch['images'][sample_id]
|
25 |
+
texs_predicted = beam_search_decode(transformer, image, top_k=self.top_k, max_length=self.max_length)
|
|
|
26 |
image = self.tensor_to_PIL(image)
|
27 |
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)))
|
28 |
self.logger.log_image(key="samples", images=[image],
|
|
|
30 |
|
31 |
|
32 |
@torch.inference_mode()
|
33 |
+
def beam_search_decode(transformer, image, image_transform=None, top_k=10, max_length=100):
|
34 |
"""Performs decoding maintaining k best candidates"""
|
|
|
35 |
|
36 |
def get_tgt_padding_mask(tgt):
|
37 |
mask = tgt == tex_tokenizer.token_to_id("[SEP]")
|
|
|
39 |
mask = mask.to(transformer.device, torch.bool)
|
40 |
return mask
|
41 |
|
42 |
+
if image_transform:
|
43 |
+
image = image_transform(image)
|
44 |
+
|
45 |
+
assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"
|
46 |
src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
|
|
|
|
|
|
|
|
|
|
|
47 |
memory = transformer.encode(src)
|
48 |
|
49 |
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
|
|
79 |
padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
|
80 |
tex_tokenizer.token_to_id("[PAD]")).tolist()
|
81 |
texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
|
82 |
+
texs = [tex.replace("\\ ", "\\") for tex in texs]
|
83 |
+
return texs
|
84 |
+
|
85 |
+
|
86 |
+
def average_checkpoints(model_type, checkpoints_dir):
|
87 |
+
"""Returns model averaged from checkpoints
|
88 |
+
Args:
|
89 |
+
:model_type: -- pytorch_lightning.LightningModule that corresponds to checkpoints
|
90 |
+
:checkpoints_dir: -- path to checkpoints
|
91 |
+
"""
|
92 |
+
checkpoints = [checkpoint.path for checkpoint in os.scandir(checkpoints_dir)]
|
93 |
+
n_models = len(checkpoints)
|
94 |
+
assert n_models > 0
|
95 |
+
average_model = model_type.load_from_checkpoint(checkpoints[0])
|
96 |
+
|
97 |
+
for checkpoint in checkpoints[1:]:
|
98 |
+
model = model_type.load_from_checkpoint(checkpoint)
|
99 |
+
for weight, weight_to_add in zip(average_model.parameters(), model.parameters()):
|
100 |
+
weight.data.add_(weight_to_add.data)
|
101 |
+
|
102 |
+
for weight in average_model.parameters():
|
103 |
+
weight.data.divide_(n_models)
|
104 |
+
|
105 |
+
return average_model
|