# NOTE(review): removed web-extraction artifacts that preceded the module
# ("Spaces:", "Runtime error" x2) — they were not part of the source.
from constants import TOKENIZER_PATH
import einops
import random
from pytorch_lightning.callbacks import Callback
import torch
from torchvision import transforms
class LogImageTexCallback(Callback):
    """Logs one randomly chosen validation sample: the input image with its
    ground-truth and model-predicted TeX strings, via ``logger.log_image``."""

    def __init__(self, logger):
        """
        Args:
            logger: experiment logger exposing ``log_image`` (e.g. a
                pytorch_lightning WandbLogger).
        """
        self.logger = logger
        # NOTE(review): torch.load unpickles arbitrary objects — TOKENIZER_PATH
        # must point to a trusted file.
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        """Log a single random sample, only for the first batch of the first
        validation dataloader (keeps logging cost to one image per epoch)."""
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        # Bug fix: decode() takes (transformer, image); the tokenizer was
        # previously passed as the image argument.
        tex_predicted, tex_ids = decode(transformer, image)
        image = self.tensor_to_PIL(image)
        # .tolist() yields plain Python ints, which tokenizer.decode expects;
        # list(tensor) would yield 0-dim tensors instead.
        tex_true = self.tex_tokenizer.decode(
            batch['tex_ids'][sample_id].to('cpu', torch.int).tolist(),
            skip_special_tokens=True)
        self.logger.log_image(key="samples", images=[image],
                              caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])
# parser.add_argument( | |
# "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False, | |
# action="store_true", dest="tune" | |
# ) | |
# if args.new_dataset: | |
# datamodule.batch_size = 1 | |
# transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda() | |
# tuner = Trainer(accelerator="gpu" if args.gpus else "cpu", | |
# gpus=args.gpus, | |
# strategy=TRAINER_STRATEGY, | |
# enable_progress_bar=True, | |
# enable_checkpointing=False, | |
# auto_scale_batch_size=True, | |
# num_sanity_val_steps=0, | |
# logger=False | |
# ) | |
# tuner.tune(transformer_for_tuning, datamodule=datamodule) | |
# torch.save(datamodule, DATASET_PATH) | |
# TUNER_DIR = "resources/pl_tuner_checkpoints" | |
# from pytorch_lightning import seed_everything | |
# parser.add_argument( | |
# "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False, | |
# action="store_true", dest="deterministic" | |
# ) | |
# if args.deterministic: | |
# seed_everything(42, workers=True) | |
# def generate_normalize_transform(dataset: TexImageDataset): | |
# """Returns a normalize layer with mean and std computed after iterating over dataset""" | |
# | |
# mean = 0 | |
# std = 0 | |
# for item in tqdm.tqdm(dataset, "Computing dataset image stats"): | |
# image = item['image'] | |
# mean += image.mean() | |
# std += image.std() | |
# | |
# mean /= len(dataset) | |
# std /= len(dataset) | |
# normalize = T.Normalize(mean, std) | |
# return normalize | |
# class _TransformerTuner(Transformer): | |
# """ | |
# When using trainer.tune, batches from dataloader get passed directly to forward, | |
# so this subclass takes care of that | |
# """ | |
# | |
# def forward(self, batch, batch_idx): | |
# src = batch['images'] | |
# tgt = batch['tex_ids'] | |
# tgt_input = tgt[:, :-1] | |
# tgt_output = tgt[:, 1:] | |
# src_mask = None | |
# tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device, | |
# torch.ByteTensor.dtype) | |
# memory_mask = None | |
# src_padding_mask = None | |
# tgt_padding_mask = batch['tex_attention_masks'][:, :-1] | |
# tgt_padding_mask = tgt_padding_mask.masked_fill( | |
# tgt_padding_mask == 0, float('-inf') | |
# ).masked_fill( | |
# tgt_padding_mask == 1, 0 | |
# ) | |
# | |
# src = self.src_tok_emb(src) | |
# tgt_input = self.tgt_tok_emb(tgt_input) | |
# outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask) | |
# outs = self.generator(outs) | |
# | |
# loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long()) | |
# return loss | |
# | |
# def validation_step(self, batch, batch_idx): | |
# return self(batch, batch_idx) | |
def decode(transformer, image, tex_tokenizer=None, max_len=30):
    """Greedily decode a TeX token sequence for a single image.

    Args:
        transformer: model exposing ``device``, a ``transformer`` submodule
            with ``generate_square_subsequent_mask``, and a call returning
            per-position token logits of shape (batch, seq, vocab).
        image: image tensor of shape (C, H, W) — assumed from the unsqueeze
            to a single batch dim below; TODO confirm against callers.
        tex_tokenizer: tokenizer with ``token_to_id``/``decode``; when None it
            is loaded from TOKENIZER_PATH (backward-compatible fallback —
            passing it in avoids reloading from disk on every call).
        max_len: maximum length of the generated id sequence, including the
            leading [CLS] (previously hard-coded to 30).

    Returns:
        Tuple ``(tex, tex_ids)`` — the decoded string (special tokens
        stripped) and the raw id list including [CLS] and, if reached, [SEP].
    """
    if tex_tokenizer is None:
        tex_tokenizer = torch.load(TOKENIZER_PATH)
    # Hoist loop-invariant lookups out of the generation loop.
    sep_id = tex_tokenizer.token_to_id("[SEP]")
    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
    src = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W): add batch dim
    while tex_ids[-1] != sep_id and len(tex_ids) < max_len:
        # NOTE(review): ids fed as float — presumably what tgt_tok_emb
        # expects here; confirm, since nn.Embedding would require long.
        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(
            tgt.shape[1]).to(transformer.device, torch.bool)
        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
        # Greedy step: argmax over the vocab logits at the last position
        # (outs is (batch, seq, vocab); equivalent to the old rearrange+index).
        next_id = outs[0, -1, :].argmax().item()
        tex_ids.append(next_id)
    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
    return tex, tex_ids