"""Training utilities for the image-to-TeX transformer: a PyTorch Lightning callback
that logs validation samples to the experiment logger, and a greedy decoding helper."""

import random

import einops
import torch
from pytorch_lightning.callbacks import Callback
from torchvision import transforms

from constants import TOKENIZER_PATH


class LogImageTexCallback(Callback):
    """Logs a random validation image together with its true and predicted TeX strings."""

    def __init__(self, logger):
        self.logger = logger
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        # Only log one random sample from the first batch of the first validation dataloader.
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        tex_predicted, tex_ids = decode(transformer, self.tex_tokenizer, image)
        image = self.tensor_to_PIL(image)
        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
                                             skip_special_tokens=True)
        self.logger.log_image(key="samples", images=[image],
                              caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])
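
# Usage sketch (illustrative, not part of the original code): the callback only assumes a
# logger that exposes `log_image`, e.g. pytorch_lightning.loggers.WandbLogger; `transformer`
# and `datamodule` below are placeholders for the project's own model and datamodule.
#
# from pytorch_lightning import Trainer
# from pytorch_lightning.loggers import WandbLogger
#
# wandb_logger = WandbLogger(project="ml2transformer")
# trainer = Trainer(
#     accelerator="gpu",
#     gpus=1,
#     logger=wandb_logger,
#     callbacks=[LogImageTexCallback(wandb_logger)],
# )
# trainer.fit(transformer, datamodule=datamodule)
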
# parser.add_argument(
#     "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
#     action="store_true", dest="tune"
# )

# if args.new_dataset:
#     datamodule.batch_size = 1

# transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
# tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
#                 gpus=args.gpus,
#                 strategy=TRAINER_STRATEGY,
#                 enable_progress_bar=True,
#                 enable_checkpointing=False,
#                 auto_scale_batch_size=True,
#                 num_sanity_val_steps=0,
#                 logger=False
#                 )
# tuner.tune(transformer_for_tuning, datamodule=datamodule)
# torch.save(datamodule, DATASET_PATH)

# TUNER_DIR = "resources/pl_tuner_checkpoints"

# from pytorch_lightning import seed_everything
# parser.add_argument(
#     "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
#     action="store_true", dest="deterministic"
# )
# if args.deterministic:
#     seed_everything(42, workers=True)

# def generate_normalize_transform(dataset: TexImageDataset):
#     """Returns a normalize layer with mean and std computed after iterating over dataset"""
#
#     mean = 0
#     std = 0
#     for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
#         image = item['image']
#         mean += image.mean()
#         std += image.std()
#
#     mean /= len(dataset)
#     std /= len(dataset)
#     normalize = T.Normalize(mean, std)
#     return normalize

# class _TransformerTuner(Transformer):
#     """
#     When using trainer.tune, batches from dataloader get passed directly to forward,
#     so this subclass takes care of that
#     """
#
#     def forward(self, batch, batch_idx):
#         src = batch['images']
#         tgt = batch['tex_ids']
#         tgt_input = tgt[:, :-1]
#         tgt_output = tgt[:, 1:]
#         src_mask = None
#         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
#                                                                                            torch.ByteTensor.dtype)
#         memory_mask = None
#         src_padding_mask = None
#         tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
#         tgt_padding_mask = tgt_padding_mask.masked_fill(
#             tgt_padding_mask == 0, float('-inf')
#         ).masked_fill(
#             tgt_padding_mask == 1, 0
#         )
#
#         src = self.src_tok_emb(src)
#         tgt_input = self.tgt_tok_emb(tgt_input)
#         outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
#         outs = self.generator(outs)
#
#         loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
#         return loss
#
#     def validation_step(self, batch, batch_idx):
#         return self(batch, batch_idx)


@torch.inference_mode()
def decode(transformer, tex_tokenizer, image):
    """Greedily decodes a TeX string for a single image, one token at a time.

    Returns the decoded string and the raw list of predicted token ids.
    """
    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
    src = einops.rearrange(image, "c h w -> () c h w")
    # Keep appending the most likely next token until [SEP] is produced or 30 tokens are reached.
    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(transformer.device,
                                                                                            torch.bool)
        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
        outs = einops.rearrange(outs, 'b n prob -> b prob n')
        next_id = outs[0, :, -1].argmax().item()
        tex_ids.append(next_id)
    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
    return tex, tex_ids
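
# Example of using `decode` for single-image inference (illustrative sketch; the Transformer
# class, its import path, the checkpoint path, and the image file name are assumptions about
# the rest of the project, and preprocessing should match the dataset's own transforms):
#
# from PIL import Image
#
# from train import Transformer  # hypothetical import path for the Lightning model
#
# transformer = Transformer.load_from_checkpoint("checkpoint.ckpt").eval()
# tex_tokenizer = torch.load(TOKENIZER_PATH)
# image = transforms.ToTensor()(Image.open("formula.png"))
# tex, tex_ids = decode(transformer, tex_tokenizer, image)
# print(tex)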