Spaces:
Runtime error
Runtime error
dkoshman
committed on
Commit
•
29bcc5f
1
Parent(s):
c308f77
beam search
Browse files- constants.py +1 -0
- data_preprocessing.py +2 -3
- model.py +13 -1
- train.py +67 -39
- utils.py +59 -92
constants.py
CHANGED
@@ -5,6 +5,7 @@ DATA_DIR = "data"
|
|
5 |
LATEX_PATH = "resources/latex.json"
|
6 |
TRAINER_DIR = "resources/trainer"
|
7 |
TOKENIZER_PATH = "resources/tokenizer.pt"
|
|
|
8 |
|
9 |
NUM_DATALOADER_WORKERS = 4
|
10 |
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
|
|
5 |
LATEX_PATH = "resources/latex.json"
|
6 |
TRAINER_DIR = "resources/trainer"
|
7 |
TOKENIZER_PATH = "resources/tokenizer.pt"
|
8 |
+
DATAMODULE_PATH = "resources/datamodule.pt"
|
9 |
|
10 |
NUM_DATALOADER_WORKERS = 4
|
11 |
PERSISTENT_WORKERS = True # whether to keep dataloader workers alive instead of shutting them down at the end of an epoch; NOTE(review): presumably forwarded to DataLoader(persistent_workers=...) - confirm against the caller
|
data_preprocessing.py
CHANGED
@@ -119,7 +119,6 @@ def generate_tex_tokenizer(dataloader):
|
|
119 |
|
120 |
texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
|
121 |
|
122 |
-
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
123 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
124 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
125 |
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
@@ -150,10 +149,10 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
150 |
self.batch_size = batch_size
|
151 |
self.save_hyperparameters()
|
152 |
|
153 |
-
def
|
154 |
tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
|
155 |
-
print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
|
156 |
torch.save(tokenizer, TOKENIZER_PATH)
|
|
|
157 |
|
158 |
def _shared_dataloader(self, dataset, **kwargs):
|
159 |
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
|
|
119 |
|
120 |
texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
|
121 |
|
|
|
122 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
123 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
124 |
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
|
|
149 |
self.batch_size = batch_size
|
150 |
self.save_hyperparameters()
|
151 |
|
152 |
+
def train_tokenizer(self):
    """Fit a TeX tokenizer on the training split, save it to TOKENIZER_PATH and return it."""
    train_loader = DataLoader(self.train_dataset, batch_size=32, num_workers=16)
    tex_tokenizer = generate_tex_tokenizer(train_loader)
    # Persist the tokenizer so other code can load it back from TOKENIZER_PATH.
    torch.save(tex_tokenizer, TOKENIZER_PATH)
    return tex_tokenizer
|
156 |
|
157 |
def _shared_dataloader(self, dataset, **kwargs):
|
158 |
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
model.py
CHANGED
@@ -130,12 +130,24 @@ class Transformer(pl.LightningModule):
|
|
130 |
|
131 |
def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
|
132 |
tgt_padding_mask=None):
|
|
|
|
|
|
|
|
|
133 |
src = self.src_tok_emb(src)
|
134 |
tgt = self.tgt_tok_emb(tgt)
|
135 |
-
|
136 |
outs = self.transformer(src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
|
137 |
return self.generator(outs)
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
def _shared_step(self, batch):
|
140 |
src = batch['images']
|
141 |
tgt = batch['tex_ids']
|
|
|
130 |
|
131 |
def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
            tgt_padding_mask=None):
    """Embed both inputs, run the transformer and project its output with the generator.

    Mask convention: in ``src_mask``, ``tgt_mask`` and ``memory_mask`` a ``True``
    position is not allowed to attend while ``False`` leaves it unchanged; in the
    padding masks a ``True`` position is ignored while ``False`` leaves it unchanged.
    """
    embedded_src = self.src_tok_emb(src)
    embedded_tgt = self.tgt_tok_emb(tgt)
    # Masks are forwarded positionally, in the order the wrapped transformer receives them.
    masks = (src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
    return self.generator(self.transformer(embedded_src, embedded_tgt, *masks))
|
141 |
|
142 |
+
def encode(self, src, src_mask=None, src_padding_mask=None):
    """Embed ``src`` and return the output of the transformer's encoder for it.

    ``src_mask`` and ``src_padding_mask`` are handed to the encoder unchanged.
    """
    embedded_src = self.src_tok_emb(src)
    encoder = self.transformer.encoder
    return encoder(embedded_src, src_mask, src_padding_mask)
|
145 |
+
|
146 |
+
def decode(self, tgt, memory=None, tgt_mask=None, memory_mask=None, tgt_padding_mask=None):
    """Embed ``tgt``, run the transformer's decoder against ``memory`` and apply the generator.

    The three masks are handed to the decoder unchanged.
    """
    embedded_tgt = self.tgt_tok_emb(tgt)
    decoder = self.transformer.decoder
    decoded = decoder(embedded_tgt, memory, tgt_mask, memory_mask, tgt_padding_mask)
    return self.generator(decoded)
|
150 |
+
|
151 |
def _shared_step(self, batch):
|
152 |
src = batch['images']
|
153 |
tgt = batch['tex_ids']
|
train.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from constants import TRAINER_DIR
|
2 |
from data_preprocessing import LatexImageDataModule
|
3 |
from model import Transformer
|
4 |
from utils import LogImageTexCallback
|
@@ -11,77 +11,105 @@ from pytorch_lightning import Trainer
|
|
11 |
import torch
|
12 |
|
13 |
|
14 |
-
# TODO: update python,
|
15 |
-
#
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def parse_args():
|
19 |
parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
|
20 |
|
21 |
-
parser.add_argument("
|
22 |
-
|
23 |
-
|
24 |
-
parser.add_argument("-l", "-log", help="
|
25 |
action="store_true", dest="log")
|
26 |
-
parser.add_argument("-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
|
33 |
("dim_feedforward", 2048), ("dropout", 0.1)]
|
34 |
-
parser.add_argument("-t", "
|
35 |
-
help="
|
|
|
36 |
|
37 |
args = parser.parse_args()
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
return args
|
42 |
|
43 |
|
44 |
def main():
|
|
|
45 |
args = parse_args()
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
|
|
|
|
|
|
|
|
50 |
if args.log:
|
51 |
logger = WandbLogger(f"img2tex", log_model=True)
|
52 |
-
callbacks = [LogImageTexCallback(logger),
|
53 |
-
LearningRateMonitor(logging_interval=
|
54 |
ModelCheckpoint(save_top_k=10,
|
55 |
monitor="val_loss",
|
56 |
mode="min",
|
57 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
58 |
-
else:
|
59 |
-
logger = None
|
60 |
-
callbacks = []
|
61 |
|
62 |
trainer = Trainer(max_epochs=args.max_epochs,
|
63 |
accelerator="cpu" if args.gpus is None else "gpu",
|
64 |
gpus=args.gpus,
|
65 |
logger=logger,
|
66 |
-
strategy="
|
67 |
enable_progress_bar=True,
|
68 |
default_root_dir=TRAINER_DIR,
|
69 |
callbacks=callbacks,
|
70 |
check_val_every_n_epoch=5)
|
71 |
|
72 |
-
transformer = Transformer(num_encoder_layers=args.transformer_args[
|
73 |
-
num_decoder_layers=args.transformer_args[
|
74 |
-
d_model=args.transformer_args[
|
75 |
-
nhead=args.transformer_args[
|
76 |
-
dim_feedforward=args.transformer_args[
|
77 |
-
dropout=args.transformer_args[
|
78 |
-
image_width=datamodule.hparams[
|
79 |
-
image_height=datamodule.hparams[
|
80 |
-
tgt_vocab_size=
|
81 |
-
pad_idx=
|
82 |
|
83 |
trainer.fit(transformer, datamodule=datamodule)
|
84 |
-
trainer.
|
85 |
|
86 |
|
87 |
if __name__ == "__main__":
|
|
|
1 |
+
from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
|
2 |
from data_preprocessing import LatexImageDataModule
|
3 |
from model import Transformer
|
4 |
from utils import LogImageTexCallback
|
|
|
11 |
import torch
|
12 |
|
13 |
|
14 |
+
# TODO: update python, make tex tokens always decodable, ensemble last checkpoints,
|
15 |
+
# clear checkpoint data, build full dataset, train, export model to torchscript, write spaces interface
|
16 |
+
|
17 |
+
def check_setup():
    """Make sure the saved datamodule and the tokenizer file exist before anything else runs.

    Sets ``TOKENIZERS_PARALLELISM`` to ``"false"`` in the process environment,
    creates and saves a default datamodule when ``DATAMODULE_PATH`` is missing,
    and trains a tokenizer when ``TOKENIZER_PATH`` is missing.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    datamodule_on_disk = os.path.isfile(DATAMODULE_PATH)
    if not datamodule_on_disk:
        default_datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16,
                                                  random_magnitude=5)
        torch.save(default_datamodule, DATAMODULE_PATH)

    if os.path.isfile(TOKENIZER_PATH):
        return
    # The tokenizer is always trained from the datamodule as stored on disk.
    torch.load(DATAMODULE_PATH).train_tokenizer()
|
25 |
+
|
26 |
|
27 |
def parse_args():
    """Parse the training command line and normalise the parsed values.

    The saved datamodule is loaded from ``DATAMODULE_PATH`` only so that its
    current hyperparameters can be shown in the ``-d`` help text.

    Returns:
        argparse.Namespace with
        ``gpus``: list of int gpu ids, or ``None`` when none were given;
        ``log``: bool;
        ``max_epochs``: int or ``None``;
        ``datamodule_args``: ``None`` or a dict keyed by image_width, image_height,
        batch_size and random_magnitude;
        ``transformer_args``: dict keyed by the transformer init argument names,
        holding the reference values when ``-t`` is not given.
    """
    parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("gpus", type=int, default=None,
                        help=f"Ids of gpus in range 0..{torch.cuda.device_count()} to train on, "
                             "if not provided, then trains on cpu", nargs="*")
    parser.add_argument("-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
                        action="store_true", dest="log")
    parser.add_argument("-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")

    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]

    datamodule = torch.load(DATAMODULE_PATH)
    parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
                        help="Create new datamodule and exit, current parameters:\n" +
                             "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))

    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                        ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
                        help="Transformer init args, reference values:\n" +
                             "\n".join(f"{k}\t{v}" for k, v in transformer_args))

    args = parser.parse_args()

    # argparse stores an empty list (not the ``None`` default) for an omitted
    # ``nargs="*"`` positional; map it to None so "no gpus" means cpu as the help says.
    if not args.gpus:
        args.gpus = None

    if args.datamodule_args:
        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))

    if args.transformer_args:
        # ``-t`` has no ``type=``, so its values arrive as strings; convert each one
        # to the type of its reference value (int, or float for dropout).
        try:
            args.transformer_args = {name: type(reference)(value)
                                     for (name, reference), value in zip(transformer_args, args.transformer_args)}
        except ValueError as error:
            parser.error(f"invalid transformer argument: {error}")
    else:
        args.transformer_args = dict(transformer_args)

    return args
|
61 |
|
62 |
|
63 |
def main():
    """Prepare resources, then either rebuild the datamodule or train and test the model.

    With ``-d`` a new ``LatexImageDataModule`` is built from the given values, its
    tokenizer is retrained, the vocabulary size is printed, the datamodule is saved
    to ``DATAMODULE_PATH`` and the function returns without training.
    Otherwise the saved datamodule and tokenizer are loaded, a ``Transformer`` is
    built from the parsed arguments, fitted and then tested.
    """
    check_setup()
    args = parse_args()
    if args.datamodule_args:
        datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
                                          image_height=args.datamodule_args["image_height"],
                                          batch_size=args.datamodule_args["batch_size"],
                                          random_magnitude=args.datamodule_args["random_magnitude"])
        datamodule.train_tokenizer()
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
        torch.save(datamodule, DATAMODULE_PATH)
        return

    datamodule = torch.load(DATAMODULE_PATH)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    logger = None
    callbacks = []
    if args.log:
        logger = WandbLogger("img2tex", log_model=True)
        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
                     LearningRateMonitor(logging_interval="step"),
                     ModelCheckpoint(save_top_k=10,
                                     monitor="val_loss",
                                     mode="min",
                                     filename="img2tex-{epoch:02d}-{val_loss:.2f}")]

    # argparse yields an empty list (never None) for an omitted ``nargs="*"``
    # positional, so test truthiness: no gpu ids means training on cpu.
    gpus = args.gpus or None
    trainer = Trainer(max_epochs=args.max_epochs,
                      accelerator="gpu" if gpus else "cpu",
                      gpus=gpus,
                      logger=logger,
                      strategy="ddp_find_unused_parameters_false",
                      enable_progress_bar=True,
                      default_root_dir=TRAINER_DIR,
                      callbacks=callbacks,
                      check_val_every_n_epoch=5)

    transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
                              num_decoder_layers=args.transformer_args["num_decoder_layers"],
                              d_model=args.transformer_args["d_model"],
                              nhead=args.transformer_args["nhead"],
                              dim_feedforward=args.transformer_args["dim_feedforward"],
                              dropout=args.transformer_args["dropout"],
                              image_width=datamodule.hparams["image_width"],
                              image_height=datamodule.hparams["image_height"],
                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))

    trainer.fit(transformer, datamodule=datamodule)
    trainer.test(transformer, datamodule=datamodule)
|
113 |
|
114 |
|
115 |
if __name__ == "__main__":
|
utils.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1 |
from constants import TOKENIZER_PATH
|
|
|
2 |
|
3 |
import einops
|
4 |
import random
|
5 |
from pytorch_lightning.callbacks import Callback
|
6 |
import torch
|
|
|
7 |
from torchvision import transforms
|
8 |
|
9 |
|
10 |
class LogImageTexCallback(Callback):
|
11 |
-
def __init__(self, logger):
|
12 |
self.logger = logger
|
|
|
|
|
13 |
self.tex_tokenizer = torch.load(TOKENIZER_PATH)
|
14 |
self.tensor_to_PIL = transforms.ToPILImage()
|
15 |
|
@@ -18,101 +22,64 @@ class LogImageTexCallback(Callback):
|
|
18 |
return
|
19 |
sample_id = random.randint(0, len(batch['images']) - 1)
|
20 |
image = batch['images'][sample_id]
|
21 |
-
|
|
|
22 |
image = self.tensor_to_PIL(image)
|
23 |
-
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int))
|
24 |
-
|
25 |
-
|
26 |
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
# strategy=TRAINER_STRATEGY,
|
39 |
-
# enable_progress_bar=True,
|
40 |
-
# enable_checkpointing=False,
|
41 |
-
# auto_scale_batch_size=True,
|
42 |
-
# num_sanity_val_steps=0,
|
43 |
-
# logger=False
|
44 |
-
# )
|
45 |
-
# tuner.tune(transformer_for_tuning, datamodule=datamodule)
|
46 |
-
# torch.save(datamodule, DATASET_PATH)
|
47 |
-
# TUNER_DIR = "resources/pl_tuner_checkpoints"
|
48 |
-
# from pytorch_lightning import seed_everything
|
49 |
-
# parser.add_argument(
|
50 |
-
# "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
|
51 |
-
# action="store_true", dest="deterministic"
|
52 |
-
# )
|
53 |
-
# if args.deterministic:
|
54 |
-
# seed_everything(42, workers=True)
|
55 |
-
# def generate_normalize_transform(dataset: TexImageDataset):
|
56 |
-
# """Returns a normalize layer with mean and std computed after iterating over dataset"""
|
57 |
-
#
|
58 |
-
# mean = 0
|
59 |
-
# std = 0
|
60 |
-
# for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
|
61 |
-
# image = item['image']
|
62 |
-
# mean += image.mean()
|
63 |
-
# std += image.std()
|
64 |
-
#
|
65 |
-
# mean /= len(dataset)
|
66 |
-
# std /= len(dataset)
|
67 |
-
# normalize = T.Normalize(mean, std)
|
68 |
-
# return normalize
|
69 |
-
# class _TransformerTuner(Transformer):
|
70 |
-
# """
|
71 |
-
# When using trainer.tune, batches from dataloader get passed directly to forward,
|
72 |
-
# so this subclass takes care of that
|
73 |
-
# """
|
74 |
-
#
|
75 |
-
# def forward(self, batch, batch_idx):
|
76 |
-
# src = batch['images']
|
77 |
-
# tgt = batch['tex_ids']
|
78 |
-
# tgt_input = tgt[:, :-1]
|
79 |
-
# tgt_output = tgt[:, 1:]
|
80 |
-
# src_mask = None
|
81 |
-
# tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
|
82 |
-
# torch.ByteTensor.dtype)
|
83 |
-
# memory_mask = None
|
84 |
-
# src_padding_mask = None
|
85 |
-
# tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
|
86 |
-
# tgt_padding_mask = tgt_padding_mask.masked_fill(
|
87 |
-
# tgt_padding_mask == 0, float('-inf')
|
88 |
-
# ).masked_fill(
|
89 |
-
# tgt_padding_mask == 1, 0
|
90 |
-
# )
|
91 |
-
#
|
92 |
-
# src = self.src_tok_emb(src)
|
93 |
-
# tgt_input = self.tgt_tok_emb(tgt_input)
|
94 |
-
# outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
|
95 |
-
# outs = self.generator(outs)
|
96 |
-
#
|
97 |
-
# loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
|
98 |
-
# return loss
|
99 |
-
#
|
100 |
-
# def validation_step(self, batch, batch_idx):
|
101 |
-
# return self(batch, batch_idx)
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
@torch.inference_mode()
|
105 |
-
def decode(transformer, image):
|
106 |
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from constants import TOKENIZER_PATH
|
2 |
+
from data_preprocessing import RandomizeImageTransform
|
3 |
|
4 |
import einops
|
5 |
import random
|
6 |
from pytorch_lightning.callbacks import Callback
|
7 |
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
from torchvision import transforms
|
10 |
|
11 |
|
12 |
class LogImageTexCallback(Callback):
|
13 |
+
def __init__(self, logger, top_k, max_length):
|
14 |
self.logger = logger
|
15 |
+
self.top_k = top_k
|
16 |
+
self.max_length = max_length
|
17 |
self.tex_tokenizer = torch.load(TOKENIZER_PATH)
|
18 |
self.tensor_to_PIL = transforms.ToPILImage()
|
19 |
|
|
|
22 |
return
|
23 |
sample_id = random.randint(0, len(batch['images']) - 1)
|
24 |
image = batch['images'][sample_id]
|
25 |
+
texs_predicted, texs_ids = beam_search_decode(transformer, image, transform_image=False, top_k=self.top_k,
|
26 |
+
max_length=self.max_length)
|
27 |
image = self.tensor_to_PIL(image)
|
28 |
+
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)))
|
29 |
+
self.logger.log_image(key="samples", images=[image],
|
30 |
+
caption=[f"True: {tex_true}\nPredicted: " + "\n".join(texs_predicted)])
|
31 |
|
32 |
|
33 |
+
@torch.inference_mode()
def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
    """Decode ``image`` into TeX with beam search, keeping up to ``top_k`` best candidates.

    Args:
        transformer: model exposing ``encode``, ``decode``, ``device``, ``hparams`` and
            ``transformer.generate_square_subsequent_mask``.
        image: 3 dimensional tensor (c h w).
        transform_image: when true, the image is first passed through
            ``RandomizeImageTransform`` with ``random_magnitude=0`` using the model's
            image size hyperparameters.
        top_k: maximum number of candidates kept after every step.
        max_length: decoding stops once the best candidate holds this many token ids.

    Returns:
        ``(texs, candidates_tex_ids)``: the decoded strings (special tokens skipped) and
        the candidates' token id lists, where everything after a candidate's first
        ``[SEP]`` is replaced by ``[PAD]``.
    """
    assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"

    tex_tokenizer = torch.load(TOKENIZER_PATH)
    cls_id = tex_tokenizer.token_to_id("[CLS]")
    sep_id = tex_tokenizer.token_to_id("[SEP]")
    pad_id = tex_tokenizer.token_to_id("[PAD]")

    def get_tgt_padding_mask(tgt):
        # The running count of [SEP] tokens is non-zero from the first [SEP]
        # (inclusive) onwards, so the mask is True for that whole tail.
        mask = tgt == sep_id
        mask = torch.cumsum(mask, dim=1)
        return mask.to(transformer.device, torch.bool)

    src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
    if transform_image:
        # The height must come from "image_height"; it was previously read from
        # "image_width". NOTE(review): assumes the model stores both hyperparameters.
        image_transform = RandomizeImageTransform(width=transformer.hparams["image_width"],
                                                  height=transformer.hparams["image_height"],
                                                  random_magnitude=0)
        src = image_transform(src)
    memory = transformer.encode(src)

    candidates_tex_ids = [[cls_id]]
    candidates_log_prob = torch.tensor([0], dtype=torch.float, device=transformer.device)

    # Candidates come out of topk best-first, so only the best one decides when to stop.
    while candidates_tex_ids[0][-1] != sep_id and len(candidates_tex_ids[0]) < max_length:
        # NOTE(review): ids are kept as float as before; presumably the target
        # embedding casts them itself - confirm.
        candidates_tex_ids = torch.tensor(candidates_tex_ids, dtype=torch.float, device=transformer.device)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(candidates_tex_ids.shape[1]).to(
            transformer.device, torch.bool)
        # Every candidate attends to the same encoded image.
        shared_memories = einops.repeat(memory, f"one n d_model -> ({candidates_tex_ids.shape[0]} one) n d_model")
        outs = transformer.decode(tgt=candidates_tex_ids,
                                  memory=shared_memories,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  tgt_padding_mask=get_tgt_padding_mask(candidates_tex_ids))
        # Keep only the prediction for the last position: (b, n, prob) -> (b, prob).
        outs = outs[:, -1, :]
        vocab_size = outs.shape[1]
        outs = F.log_softmax(outs, dim=1)
        # Score of a continuation = score of its candidate + log prob of the new token.
        outs = outs + einops.rearrange(candidates_log_prob, "prob -> prob ()")
        outs = einops.rearrange(outs, 'b prob -> (b prob)')
        # topk raises when k exceeds the number of scored continuations.
        candidates_log_prob, indices = torch.topk(outs, k=min(top_k, outs.shape[0]))

        new_candidates = []
        for index in indices:
            # A flat index encodes (candidate row, token id).
            candidate_id, token_id = divmod(index.item(), vocab_size)
            new_candidates.append(candidates_tex_ids[candidate_id].to(int).tolist() + [token_id])
        candidates_tex_ids = new_candidates

    candidates_tex_ids = torch.tensor(candidates_tex_ids)
    padding_mask = get_tgt_padding_mask(candidates_tex_ids).cpu()
    # Blank out whatever was generated after a candidate's first [SEP], keeping [SEP] itself.
    candidates_tex_ids = candidates_tex_ids.masked_fill(
        padding_mask & (candidates_tex_ids != sep_id),
        pad_id).tolist()
    texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
    return texs, candidates_tex_ids
|