dkoshman committed
Commit 29bcc5f
Parent: c308f77

beam search

Files changed (5)
  1. constants.py +1 -0
  2. data_preprocessing.py +2 -3
  3. model.py +13 -1
  4. train.py +67 -39
  5. utils.py +59 -92
constants.py CHANGED
@@ -5,6 +5,7 @@ DATA_DIR = "data"
 LATEX_PATH = "resources/latex.json"
 TRAINER_DIR = "resources/trainer"
 TOKENIZER_PATH = "resources/tokenizer.pt"
+DATAMODULE_PATH = "resources/datamodule.pt"
 
 NUM_DATALOADER_WORKERS = 4
 PERSISTENT_WORKERS = True  # whether to shut down workers at the end of epoch
data_preprocessing.py CHANGED
@@ -119,7 +119,6 @@ def generate_tex_tokenizer(dataloader):
 
     texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
 
-    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
     tokenizer_trainer = tokenizers.trainers.BpeTrainer(
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
@@ -150,10 +149,10 @@ class LatexImageDataModule(pl.LightningDataModule):
         self.batch_size = batch_size
         self.save_hyperparameters()
 
-    def prepare_data(self):
+    def train_tokenizer(self):
         tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
-        print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
         torch.save(tokenizer, TOKENIZER_PATH)
+        return tokenizer
 
     def _shared_dataloader(self, dataset, **kwargs):
         tex_tokenizer = torch.load(TOKENIZER_PATH)
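Renaming `prepare_data` to `train_tokenizer` turns tokenizer training into an explicit one-off step rather than something Lightning re-runs before every fit, and the vocabulary-size printout moves to the caller in train.py. A minimal sketch of the intended call pattern, assuming the default datamodule parameters used by `check_setup()` in train.py below (variable names here are illustrative):

from data_preprocessing import LatexImageDataModule

# Build the datamodule once and train the BPE tokenizer on its train split.
# train_tokenizer() saves the tokenizer to TOKENIZER_PATH and also returns it;
# check_setup() in train.py additionally torch.save()s the datamodule itself.
datamodule = LatexImageDataModule(image_width=1024, image_height=128,
                                  batch_size=16, random_magnitude=5)
tokenizer = datamodule.train_tokenizer()
print(tokenizer.get_vocab_size())  # vocabulary-size check, now done by the caller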
model.py CHANGED
@@ -130,12 +130,24 @@ class Transformer(pl.LightningModule):
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
                 tgt_padding_mask=None):
+        """Positions set to ``True`` in the attention masks are not
+        allowed to attend; ``False`` positions are left unchanged.
+        Positions set to ``True`` in the padding masks are ignored;
+        ``False`` positions are left unchanged."""
         src = self.src_tok_emb(src)
         tgt = self.tgt_tok_emb(tgt)
-
         outs = self.transformer(src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
         return self.generator(outs)
 
+    def encode(self, src, src_mask=None, src_padding_mask=None):
+        src = self.src_tok_emb(src)
+        return self.transformer.encoder(src, src_mask, src_padding_mask)
+
+    def decode(self, tgt, memory=None, tgt_mask=None, memory_mask=None, tgt_padding_mask=None):
+        tgt = self.tgt_tok_emb(tgt)
+        outs = self.transformer.decoder(tgt, memory, tgt_mask, memory_mask, tgt_padding_mask)
+        return self.generator(outs)
+
     def _shared_step(self, batch):
         src = batch['images']
         tgt = batch['tex_ids']
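The new `encode`/`decode` split lets callers run the vision encoder once and then repeatedly invoke only the decoder and generator on a growing target sequence, which is what `beam_search_decode` in utils.py relies on. As a simpler illustration, here is a hypothetical greedy decoder in the spirit of the `decode()` helper removed from utils.py (this function is not part of the commit):

import torch

@torch.inference_mode()
def greedy_decode(transformer, src, tex_tokenizer, max_length=100):
    """Greedy counterpart of beam_search_decode: encode once, extend the target one token at a time."""
    memory = transformer.encode(src)  # the encoder runs only once
    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < max_length:
        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(
            transformer.device, torch.bool)
        outs = transformer.decode(tgt, memory=memory, tgt_mask=tgt_mask)  # decoder + generator only
        tex_ids.append(outs[0, -1].argmax().item())  # pick the most probable next token
    return tex_tokenizer.decode(tex_ids, skip_special_tokens=True)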
train.py CHANGED
@@ -1,4 +1,4 @@
-from constants import TRAINER_DIR
+from constants import TRAINER_DIR, TOKENIZER_PATH, DATAMODULE_PATH
 from data_preprocessing import LatexImageDataModule
 from model import Transformer
 from utils import LogImageTexCallback
@@ -11,77 +11,105 @@ from pytorch_lightning import Trainer
 import torch
 
 
-# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
-# crop image, adjust brightness, make tex tokens always decodable,
-# save only datamodule state?, ensemble last checkpoints, early stopping
+# TODO: update python, make tex tokens always decodable, ensemble last checkpoints,
+# clear checkpoint data, build full dataset, train, export model to torchscript, write spaces interface
+
+def check_setup():
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    if not os.path.isfile(DATAMODULE_PATH):
+        datamodule = LatexImageDataModule(image_width=1024, image_height=128, batch_size=16, random_magnitude=5)
+        torch.save(datamodule, DATAMODULE_PATH)
+    if not os.path.isfile(TOKENIZER_PATH):
+        datamodule = torch.load(DATAMODULE_PATH)
+        datamodule.train_tokenizer()
+
 
 def parse_args():
     parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
 
-    parser.add_argument("-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs")
-    parser.add_argument("-g", "-gpus", metavar="GPUS", type=int, choices=list(range(torch.cuda.device_count())),
-                        help="ids of gpus to train on, if not provided, then trains on cpu", nargs="+", dest="gpus")
-    parser.add_argument("-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
+    parser.add_argument("gpus", type=int, default=None,
+                        help=f"Ids of gpus in range 0..{torch.cuda.device_count()} to train on, "
+                             "if not provided, then trains on cpu", nargs="*")
+    parser.add_argument("-l", "-log", help="Whether to save logs of run to w&b logger, default False", default=False,
                        action="store_true", dest="log")
-    parser.add_argument("-width", help="width of images, default 1024", default=1024, type=int)
-    parser.add_argument("-height", help="height of images, default 128", default=128, type=int)
-    parser.add_argument("-r", "-randomize", default=5, type=int, dest="random_magnitude", choices=range(10),
-                        help="add random augments to images of provided magnitude in range 0..9, default 5")
-    parser.add_argument("-b", "-batch-size", help="batch size, default 16", default=16,
-                        type=int, dest="batch_size")
+    parser.add_argument("-m", "-max-epochs", help="Limit the number of training epochs", type=int, dest="max_epochs")
+
+    datamodule_args = ["image_width", "image_height", "batch_size", "random_magnitude"]
+
+    datamodule = torch.load(DATAMODULE_PATH)
+    parser.add_argument("-d", metavar="X", nargs=4, dest="datamodule_args", type=int,
+                        help="Create new datamodule and exit, current parameters:\n" +
+                             "\n".join(f"{arg}\t{datamodule.hparams[arg]}" for arg in datamodule_args))
+
     transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                         ("dim_feedforward", 2048), ("dropout", 0.1)]
-    parser.add_argument("-t", "-transformer-args", dest="transformer_args", nargs='+', default=[],
-                        help="transformer init args:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
+    parser.add_argument("-t", metavar="X", dest="transformer_args", nargs=len(transformer_args),
+                        help="Transformer init args, reference values:\n" +
+                             "\n".join(f"{k}\t{v}" for k, v in transformer_args))
 
     args = parser.parse_args()
-    for i, parameter in enumerate(args.transformer_args):
-        transformer_args[i][1] = parameter
-    args.transformer_args = dict(transformer_args)
+
+    if args.datamodule_args:
+        args.datamodule_args = dict(zip(datamodule_args, args.datamodule_args))
+
+    if args.transformer_args:
+        args.transformer_args = dict(zip(list(zip(*transformer_args))[0], args.transformer_args))
+    else:
+        args.transformer_args = dict(transformer_args)
+
     return args
 
 
 def main():
+    check_setup()
     args = parse_args()
-    datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
-                                      batch_size=args.batch_size, random_magnitude=args.random_magnitude)
-    datamodule.prepare_data()
+    if args.datamodule_args:
+        datamodule = LatexImageDataModule(image_width=args.datamodule_args["image_width"],
+                                          image_height=args.datamodule_args["image_height"],
+                                          batch_size=args.datamodule_args["batch_size"],
+                                          random_magnitude=args.datamodule_args["random_magnitude"])
+        datamodule.train_tokenizer()
+        tex_tokenizer = torch.load(TOKENIZER_PATH)
+        print(f"Vocabulary size {tex_tokenizer.get_vocab_size()}")
+        torch.save(datamodule, DATAMODULE_PATH)
+        return
 
+    datamodule = torch.load(DATAMODULE_PATH)
+    tex_tokenizer = torch.load(TOKENIZER_PATH)
+    logger = None
+    callbacks = []
    if args.log:
         logger = WandbLogger(f"img2tex", log_model=True)
-        callbacks = [LogImageTexCallback(logger),
-                     LearningRateMonitor(logging_interval='step'),
+        callbacks = [LogImageTexCallback(logger, top_k=10, max_length=20),
+                     LearningRateMonitor(logging_interval="step"),
                      ModelCheckpoint(save_top_k=10,
                                      monitor="val_loss",
                                      mode="min",
                                      filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
-    else:
-        logger = None
-        callbacks = []
 
     trainer = Trainer(max_epochs=args.max_epochs,
                       accelerator="cpu" if args.gpus is None else "gpu",
                       gpus=args.gpus,
                       logger=logger,
-                      strategy="ddp",
+                      strategy="ddp_find_unused_parameters_false",
                       enable_progress_bar=True,
                       default_root_dir=TRAINER_DIR,
                       callbacks=callbacks,
                       check_val_every_n_epoch=5)
 
-    transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
-                              num_decoder_layers=args.transformer_args['num_decoder_layers'],
-                              d_model=args.transformer_args['d_model'],
-                              nhead=args.transformer_args['nhead'],
-                              dim_feedforward=args.transformer_args['dim_feedforward'],
-                              dropout=args.transformer_args['dropout'],
-                              image_width=datamodule.hparams['image_width'],
-                              image_height=datamodule.hparams['image_height'],
-                              tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
-                              pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
+    transformer = Transformer(num_encoder_layers=args.transformer_args["num_encoder_layers"],
+                              num_decoder_layers=args.transformer_args["num_decoder_layers"],
+                              d_model=args.transformer_args["d_model"],
+                              nhead=args.transformer_args["nhead"],
+                              dim_feedforward=args.transformer_args["dim_feedforward"],
+                              dropout=args.transformer_args["dropout"],
+                              image_width=datamodule.hparams["image_width"],
+                              image_height=datamodule.hparams["image_height"],
+                              tgt_vocab_size=tex_tokenizer.get_vocab_size(),
+                              pad_idx=tex_tokenizer.token_to_id("[PAD]"))
 
     trainer.fit(transformer, datamodule=datamodule)
-    trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
+    trainer.test(transformer, datamodule=datamodule)
 
 
 if __name__ == "__main__":
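The `-d` and `-t` options both take bare values that are zipped back into keyword dictionaries after parsing, e.g. `python train.py 0 1 -l -m 100 -t 4 4 256 8 1024 0.1` (the numbers are made up). A small illustration of that mapping:

transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                    ("dim_feedforward", 2048), ("dropout", 0.1)]
cli_values = ["4", "4", "256", "8", "1024", "0.1"]  # what argparse yields for `-t 4 4 256 8 1024 0.1`

keys = list(zip(*transformer_args))[0]  # ('num_encoder_layers', ..., 'dropout')
print(dict(zip(keys, cli_values)))
# {'num_encoder_layers': '4', 'num_decoder_layers': '4', 'd_model': '256',
#  'nhead': '8', 'dim_feedforward': '1024', 'dropout': '0.1'}

Because `-t` has no `type=` converter, overridden values arrive as strings; arguments that expect ints or floats would need to be cast before reaching `Transformer(...)`.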
utils.py CHANGED
@@ -1,15 +1,19 @@
 from constants import TOKENIZER_PATH
+from data_preprocessing import RandomizeImageTransform
 
 import einops
 import random
 from pytorch_lightning.callbacks import Callback
 import torch
+import torch.nn.functional as F
 from torchvision import transforms
 
 
 class LogImageTexCallback(Callback):
-    def __init__(self, logger):
+    def __init__(self, logger, top_k, max_length):
         self.logger = logger
+        self.top_k = top_k
+        self.max_length = max_length
         self.tex_tokenizer = torch.load(TOKENIZER_PATH)
         self.tensor_to_PIL = transforms.ToPILImage()
 
@@ -18,101 +22,64 @@ class LogImageTexCallback(Callback):
             return
         sample_id = random.randint(0, len(batch['images']) - 1)
         image = batch['images'][sample_id]
-        tex_predicted, tex_ids = decode(transformer, self.tex_tokenizer, image)
+        texs_predicted, texs_ids = beam_search_decode(transformer, image, transform_image=False, top_k=self.top_k,
+                                                      max_length=self.max_length)
         image = self.tensor_to_PIL(image)
-        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
-                                             skip_special_tokens=True)
-        self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])
+        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)))
+        self.logger.log_image(key="samples", images=[image],
+                              caption=[f"True: {tex_true}\nPredicted: " + "\n".join(texs_predicted)])
 
 
-# parser.add_argument(
-#     "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
-#     action="store_true", dest="tune"
-# )
+@torch.inference_mode()
+def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
+    """Performs decoding maintaining k best candidates"""
+    assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"
 
-# if args.new_dataset:
-#     datamodule.batch_size = 1
-#     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
-#     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
-#                     gpus=args.gpus,
-#                     strategy=TRAINER_STRATEGY,
-#                     enable_progress_bar=True,
-#                     enable_checkpointing=False,
-#                     auto_scale_batch_size=True,
-#                     num_sanity_val_steps=0,
-#                     logger=False
-#                     )
-#     tuner.tune(transformer_for_tuning, datamodule=datamodule)
-#     torch.save(datamodule, DATASET_PATH)
-# TUNER_DIR = "resources/pl_tuner_checkpoints"
-# from pytorch_lightning import seed_everything
-# parser.add_argument(
-#     "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
-#     action="store_true", dest="deterministic"
-# )
-# if args.deterministic:
-#     seed_everything(42, workers=True)
-# def generate_normalize_transform(dataset: TexImageDataset):
-#     """Returns a normalize layer with mean and std computed after iterating over dataset"""
-#
-#     mean = 0
-#     std = 0
-#     for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
-#         image = item['image']
-#         mean += image.mean()
-#         std += image.std()
-#
-#     mean /= len(dataset)
-#     std /= len(dataset)
-#     normalize = T.Normalize(mean, std)
-#     return normalize
-# class _TransformerTuner(Transformer):
-#     """
-#     When using trainer.tune, batches from dataloader get passed directly to forward,
-#     so this subclass takes care of that
-#     """
-#
-#     def forward(self, batch, batch_idx):
-#         src = batch['images']
-#         tgt = batch['tex_ids']
-#         tgt_input = tgt[:, :-1]
-#         tgt_output = tgt[:, 1:]
-#         src_mask = None
-#         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
-#                                                                                            torch.ByteTensor.dtype)
-#         memory_mask = None
-#         src_padding_mask = None
-#         tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
-#         tgt_padding_mask = tgt_padding_mask.masked_fill(
-#             tgt_padding_mask == 0, float('-inf')
-#         ).masked_fill(
-#             tgt_padding_mask == 1, 0
-#         )
-#
-#         src = self.src_tok_emb(src)
-#         tgt_input = self.tgt_tok_emb(tgt_input)
-#         outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
-#         outs = self.generator(outs)
-#
-#         loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
-#         return loss
-#
-#     def validation_step(self, batch, batch_idx):
-#         return self(batch, batch_idx)
+    def get_tgt_padding_mask(tgt):
+        mask = tgt == tex_tokenizer.token_to_id("[SEP]")
+        mask = torch.cumsum(mask, dim=1)
+        mask = mask.to(transformer.device, torch.bool)
+        return mask
+
+    src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
+    if transform_image:
+        image_transform = RandomizeImageTransform(width=transformer.hparams["image_width"],
+                                                  height=transformer.hparams["image_height"],
+                                                  random_magnitude=0)
+        src = image_transform(src)
+    memory = transformer.encode(src)
 
-@torch.inference_mode()
-def decode(transformer, image):
     tex_tokenizer = torch.load(TOKENIZER_PATH)
-    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
-    src = einops.rearrange(image, "c h w -> () c h w")
-    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
-        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
-        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(transformer.device,
-                                                                                            torch.bool)
-        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
-        outs = einops.rearrange(outs, 'b n prob -> b prob n')
-        next_id = outs[0, :, -1].argmax().item()
-        tex_ids.append(next_id)
-    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
-    return tex, tex_ids
+    candidates_tex_ids = [[tex_tokenizer.token_to_id("[CLS]")]]
+    candidates_log_prob = torch.tensor([0], dtype=torch.float, device=transformer.device)
+
+    while candidates_tex_ids[0][-1] != tex_tokenizer.token_to_id("[SEP]") and len(candidates_tex_ids[0]) < max_length:
+        candidates_tex_ids = torch.tensor(candidates_tex_ids, dtype=torch.float, device=transformer.device)
+        tgt_mask = transformer.transformer.generate_square_subsequent_mask(candidates_tex_ids.shape[1]).to(
+            transformer.device, torch.bool)
+        shared_memories = einops.repeat(memory, f"one n d_model -> ({candidates_tex_ids.shape[0]} one) n d_model")
+        outs = transformer.decode(tgt=candidates_tex_ids,
+                                  memory=shared_memories,
+                                  tgt_mask=tgt_mask,
+                                  memory_mask=None,
+                                  tgt_padding_mask=get_tgt_padding_mask(candidates_tex_ids))
+        outs = einops.rearrange(outs, 'b n prob -> b prob n')[:, :, -1]
+        vocab_size = outs.shape[1]
+        outs = F.log_softmax(outs, dim=1)
+        outs += einops.rearrange(candidates_log_prob, "prob -> prob ()")
+        outs = einops.rearrange(outs, 'b prob -> (b prob)')
+        candidates_log_prob, indices = torch.topk(outs, k=top_k)
+
+        new_candidates = []
+        for index in indices:
+            candidate_id, token_id = divmod(index.item(), vocab_size)
+            new_candidates.append(candidates_tex_ids[candidate_id].to(int).tolist() + [token_id])
+        candidates_tex_ids = new_candidates
+
+    candidates_tex_ids = torch.tensor(candidates_tex_ids)
+    padding_mask = get_tgt_padding_mask(candidates_tex_ids).cpu()
+    candidates_tex_ids = candidates_tex_ids.masked_fill(
+        padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
+        tex_tokenizer.token_to_id("[PAD]")).tolist()
+    texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
+    return texs, candidates_tex_ids
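A hypothetical end-to-end use of the new decoder outside the logging callback; the checkpoint and image paths are placeholders, and loading through Lightning's standard `load_from_checkpoint` assumes the Transformer saved its hyperparameters:

import torch
from PIL import Image
from torchvision import transforms

from model import Transformer
from utils import beam_search_decode

transformer = Transformer.load_from_checkpoint("path/to/checkpoint.ckpt").eval()  # placeholder path
image = transforms.ToTensor()(Image.open("formula.png"))  # (c h w) tensor, as the assert requires

texs, tex_ids = beam_search_decode(transformer, image,
                                   transform_image=True,  # rescale via RandomizeImageTransform first
                                   top_k=10, max_length=100)
print(texs[0])  # topk keeps candidates sorted, so index 0 is the highest-probability beam

All beams are advanced to the same length: once a candidate emits [SEP], get_tgt_padding_mask marks every position from that token onward as padding, and everything generated after the first [SEP] is overwritten with [PAD] before decode_batch. The loop stops when the current best candidate ends in [SEP] or max_length is reached.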