File size: 5,018 Bytes
c308f77
4f4785c
e932abd
 
 
 
4f4785c
 
 
 
c308f77
4f4785c
c308f77
4f4785c
 
 
 
 
e932abd
 
 
4f4785c
e932abd
 
41a34cd
e932abd
 
41a34cd
 
 
 
e932abd
41a34cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e932abd
 
 
c308f77
 
e932abd
 
 
 
 
 
 
 
 
 
 
41a34cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from constants import TOKENIZER_PATH

import einops
import random
from pytorch_lightning.callbacks import Callback
import torch
from torchvision import transforms


class LogImageTexCallback(Callback):
    """Logs one random sample from the first validation batch: the input
    image alongside its ground-truth and greedily decoded TeX strings."""

    def __init__(self, logger):
        """
        Args:
            logger: experiment logger exposing ``log_image`` (e.g. WandbLogger).
        """
        self.logger = logger
        # NOTE(review): assumes the tokenizer was persisted with torch.save —
        # confirm TOKENIZER_PATH points at such a checkpoint.
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        """Log a randomly chosen sample, but only for the very first batch of
        the first validation dataloader to keep logging cheap."""
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        # BUG FIX: decode() is defined as decode(transformer, image); the old
        # call passed the tokenizer as an extra positional argument, which
        # raised a TypeError (too many arguments) at runtime.
        tex_predicted, tex_ids = decode(transformer, image)
        image = self.tensor_to_PIL(image)
        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
                                             skip_special_tokens=True)
        self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])


# NOTE(review): everything commented out below (argparse flags, batch-size
# tuner wiring, dataset-normalization helper, and the _TransformerTuner
# subclass) is dead code retained for reference — consider relying on
# version-control history instead of keeping it inline.
# parser.add_argument(
#     "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
#     action="store_true", dest="tune"
# )

# if args.new_dataset:
#     datamodule.batch_size = 1
#     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
#     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
#                     gpus=args.gpus,
#                     strategy=TRAINER_STRATEGY,
#                     enable_progress_bar=True,
#                     enable_checkpointing=False,
#                     auto_scale_batch_size=True,
#                     num_sanity_val_steps=0,
#                     logger=False
#                     )
#     tuner.tune(transformer_for_tuning, datamodule=datamodule)
#     torch.save(datamodule, DATASET_PATH)
# TUNER_DIR = "resources/pl_tuner_checkpoints"
# from pytorch_lightning import  seed_everything
#     parser.add_argument(
#         "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
#         action="store_true", dest="deterministic"
#     )
#     if args.deterministic:
#         seed_everything(42, workers=True)
# def generate_normalize_transform(dataset: TexImageDataset):
#     """Returns a normalize layer with mean and std computed after iterating over dataset"""
#
#     mean = 0
#     std = 0
#     for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
#         image = item['image']
#         mean += image.mean()
#         std += image.std()
#
#     mean /= len(dataset)
#     std /= len(dataset)
#     normalize = T.Normalize(mean, std)
#     return normalize
# class _TransformerTuner(Transformer):
#     """
#     When using trainer.tune, batches from dataloader get passed directly to forward,
#     so this subclass takes care of that
#     """
#
#     def forward(self, batch, batch_idx):
#         src = batch['images']
#         tgt = batch['tex_ids']
#         tgt_input = tgt[:, :-1]
#         tgt_output = tgt[:, 1:]
#         src_mask = None
#         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
#                                                                                            torch.ByteTensor.dtype)
#         memory_mask = None
#         src_padding_mask = None
#         tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
#         tgt_padding_mask = tgt_padding_mask.masked_fill(
#             tgt_padding_mask == 0, float('-inf')
#         ).masked_fill(
#             tgt_padding_mask == 1, 0
#         )
#
#         src = self.src_tok_emb(src)
#         tgt_input = self.tgt_tok_emb(tgt_input)
#         outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
#         outs = self.generator(outs)
#
#         loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
#         return loss
#
#     def validation_step(self, batch, batch_idx):
#         return self(batch, batch_idx)


@torch.inference_mode()
def decode(transformer, image, tex_tokenizer=None, max_len=30):
    """Greedily decode the TeX token sequence for a single image.

    Args:
        transformer: seq2seq model exposing ``device``, a ``transformer``
            attribute with ``generate_square_subsequent_mask(sz)``, and a call
            signature ``(src, tgt, src_mask=..., tgt_mask=...)`` that returns
            logits shaped (batch, seq, vocab).
        image: image tensor shaped (c, h, w) — a single un-batched sample.
        tex_tokenizer: optional pre-loaded tokenizer with ``token_to_id`` and
            ``decode``. When None it is loaded from TOKENIZER_PATH, matching
            the old behavior; passing it avoids a disk read per call.
        max_len: hard cap on the generated sequence length (previously a
            hard-coded 30).

    Returns:
        Tuple ``(tex, tex_ids)``: the decoded TeX string with special tokens
        stripped, and the raw id list including [CLS] and (if reached) [SEP].
    """
    if tex_tokenizer is None:
        tex_tokenizer = torch.load(TOKENIZER_PATH)
    # Hoisted out of the loop: the [SEP] id never changes between iterations.
    sep_id = tex_tokenizer.token_to_id("[SEP]")
    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
    # Add a leading batch dimension: (c, h, w) -> (1, c, h, w).
    src = image.unsqueeze(0)
    while tex_ids[-1] != sep_id and len(tex_ids) < max_len:
        # NOTE(review): float dtype for token ids is unusual — presumably the
        # model's embedding path casts internally; confirm against the model.
        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(transformer.device,
                                                                                            torch.bool)
        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
        # (b, n, vocab) -> (b, vocab, n), then take the last position's logits.
        outs = outs.permute(0, 2, 1)
        next_id = outs[0, :, -1].argmax().item()
        tex_ids.append(next_id)
    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
    return tex, tex_ids