from constants import TOKENIZER_PATH

import random

import einops
import torch
from pytorch_lightning.callbacks import Callback
from torchvision import transforms

# (removed dead commented-out CLI / tuner / seeding scraps — recover from VCS if needed)


class LogImageTexCallback(Callback):
    """Logs one random validation sample to the experiment logger.

    On the first batch of the first validation dataloader, picks a random
    image from the batch, runs greedy decoding with the model, and logs the
    image captioned with both the ground-truth and the predicted TeX string.
    """

    def __init__(self, logger):
        """
        Args:
            logger: a logger exposing ``log_image`` (e.g. WandbLogger).
        """
        self.logger = logger
        # NOTE(review): torch.load unpickles arbitrary objects — only safe
        # because TOKENIZER_PATH is a trusted local artifact.
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        """Log a sample prediction once per validation epoch."""
        # Only act on the very first batch of the first dataloader.
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        # BUG FIX: decode's signature is decode(transformer, image); the
        # tokenizer was previously passed as a stray positional argument,
        # which raised TypeError at runtime.
        tex_predicted, tex_ids = decode(transformer, image)
        image = self.tensor_to_PIL(image)
        tex_true = self.tex_tokenizer.decode(
            list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
            skip_special_tokens=True,
        )
        self.logger.log_image(
            key="samples",
            images=[image],
            caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"],
        )
# (removed dead commented-out generate_normalize_transform / _TransformerTuner
#  scraps — recover from VCS if needed)


@torch.inference_mode()
def decode(transformer, image, tex_tokenizer=None, max_len=30):
    """Greedily decode a TeX token sequence for a single image.

    Starts from the ``[CLS]`` token and repeatedly feeds the growing target
    sequence back into the transformer, appending the argmax token each step,
    until ``[SEP]`` is produced or ``max_len`` tokens are reached.

    Args:
        transformer: model callable as ``transformer(src, tgt, src_mask,
            tgt_mask)`` with an inner ``transformer.transformer`` exposing
            ``generate_square_subsequent_mask``; must have a ``device`` attr.
        image: single image tensor of shape (c, h, w) — batched internally.
        tex_tokenizer: optional pre-loaded tokenizer. Defaults to loading
            from TOKENIZER_PATH (previous behavior), but passing one avoids
            a disk read on every call.
        max_len: maximum decoded sequence length (default 30, as before).

    Returns:
        (tex, tex_ids): the decoded TeX string (special tokens stripped)
        and the raw list of token ids including [CLS]/[SEP].
    """
    if tex_tokenizer is None:
        # Backward-compatible default; hoisting this out via the parameter
        # avoids re-reading the tokenizer from disk per call.
        tex_tokenizer = torch.load(TOKENIZER_PATH)
    sep_id = tex_tokenizer.token_to_id("[SEP]")
    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
    # Add the batch dimension: (c, h, w) -> (1, c, h, w).
    src = einops.rearrange(image, "c h w -> () c h w")
    while tex_ids[-1] != sep_id and len(tex_ids) < max_len:
        # NOTE(review): ids are sent as float — presumably the model's
        # embedding layer expects that; confirm against the model code.
        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
        # Causal mask so each position attends only to earlier tokens.
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(
            tgt.shape[1]
        ).to(transformer.device, torch.bool)
        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
        outs = einops.rearrange(outs, 'b n prob -> b prob n')
        # Greedy pick: argmax over vocabulary at the last position.
        next_id = outs[0, :, -1].argmax().item()
        tex_ids.append(next_id)
    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
    return tex, tex_ids