import os
import random

import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import Callback
from torchvision import transforms

from constants import TOKENIZER_PATH


class LogImageTexCallback(Callback):
    """Logs one random validation image together with its ground-truth and predicted TeX."""

    def __init__(self, logger, top_k, max_length):
        super().__init__()
        self.logger = logger
        self.top_k = top_k
        self.max_length = max_length
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        # Log only once per validation run: the first batch of the first dataloader.
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        texs_predicted = beam_search_decode(transformer, image, top_k=self.top_k, max_length=self.max_length)
        image = self.tensor_to_PIL(image)
        # decode() expects plain Python ints, not a list of 0-dim tensors.
        tex_true = self.tex_tokenizer.decode(batch['tex_ids'][sample_id].to('cpu', torch.int).tolist())
        self.logger.log_image(key="samples", images=[image],
                              caption=[f"True: {tex_true}\nPredicted: " + "\n".join(texs_predicted)])


@torch.inference_mode()
def beam_search_decode(transformer, image, image_transform=None, top_k=10, max_length=100):
    """Beam-search decoding that maintains the top_k most probable candidate sequences."""

    def get_tgt_padding_mask(tgt):
        # Mask every position from the first [SEP] onwards (inclusive).
        mask = tgt == tex_tokenizer.token_to_id("[SEP]")
        mask = torch.cumsum(mask, dim=1)
        mask = mask.to(transformer.device, torch.bool)
        return mask

    if image_transform:
        image = image_transform(image)

    assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3 dimensional tensor (c h w)"
    src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
    memory = transformer.encode(src)

    tex_tokenizer = torch.load(TOKENIZER_PATH)
    candidates_tex_ids = [[tex_tokenizer.token_to_id("[CLS]")]]
    candidates_log_prob = torch.tensor([0.0], dtype=torch.float, device=transformer.device)

    # Grow candidates until the best one ends with [SEP] or reaches max_length.
    while candidates_tex_ids[0][-1] != tex_tokenizer.token_to_id("[SEP]") and len(candidates_tex_ids[0]) < max_length:
        # Token ids must be integers for the embedding lookup inside the decoder.
        candidates_tex_ids = torch.tensor(candidates_tex_ids, dtype=torch.long, device=transformer.device)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(candidates_tex_ids.shape[1]).to(
            transformer.device, torch.bool)
        # Reuse the single encoded image for every candidate in the beam.
        shared_memories = einops.repeat(memory, "one n d_model -> (b one) n d_model",
                                        b=candidates_tex_ids.shape[0])
        outs = transformer.decode(tgt=candidates_tex_ids,
                                  memory=shared_memories,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  tgt_padding_mask=get_tgt_padding_mask(candidates_tex_ids))
        # Keep only the logits for the last position of each candidate.
        outs = einops.rearrange(outs, 'b n vocab -> b vocab n')[:, :, -1]
        vocab_size = outs.shape[1]
        # Score each (candidate, next-token) pair by total sequence log-probability.
        outs = F.log_softmax(outs, dim=1)
        outs += einops.rearrange(candidates_log_prob, "b -> b ()")
        outs = einops.rearrange(outs, 'b vocab -> (b vocab)')
        candidates_log_prob, indices = torch.topk(outs, k=top_k)

        # Recover the (candidate, token) pairs from the flattened top-k indices.
        new_candidates = []
        for index in indices:
            candidate_id, token_id = divmod(index.item(), vocab_size)
            new_candidates.append(candidates_tex_ids[candidate_id].tolist() + [token_id])
        candidates_tex_ids = new_candidates

    # Replace everything after the first [SEP] with [PAD] so decode_batch can skip it.
    candidates_tex_ids = torch.tensor(candidates_tex_ids)
    padding_mask = get_tgt_padding_mask(candidates_tex_ids).cpu()
    candidates_tex_ids = candidates_tex_ids.masked_fill(
        padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
        tex_tokenizer.token_to_id("[PAD]")).tolist()
    texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
    # decode_batch joins tokens with spaces; glue TeX commands back onto their backslashes.
    texs = [tex.replace("\\ ", "\\") for tex in texs]
    return texs
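
# Example usage (a sketch; `Img2TexTransformer` and the file paths are hypothetical,
# and the transform assumes the model was trained on grayscale tensors):
#
#     from PIL import Image
#     from torchvision import transforms
#
#     model = Img2TexTransformer.load_from_checkpoint("checkpoints/best.ckpt").eval()
#     to_tensor = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
#     hypotheses = beam_search_decode(model, Image.open("formula.png"),
#                                     image_transform=to_tensor, top_k=5, max_length=200)
#     print(hypotheses[0])  # the most probable TeX string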


def average_checkpoints(model_type, checkpoints_dir):
    """Returns a model whose parameters are the element-wise average over all checkpoints.

    Args:
        model_type: pytorch_lightning.LightningModule subclass the checkpoints were saved from
        checkpoints_dir: path to the directory containing the checkpoints
    """
    checkpoints = [checkpoint.path for checkpoint in os.scandir(checkpoints_dir)]
    n_models = len(checkpoints)
    assert n_models > 0, f"No checkpoints found in {checkpoints_dir}"
    average_model = model_type.load_from_checkpoint(checkpoints[0])

    # Sum the parameters of the remaining checkpoints into the first model.
    for checkpoint in checkpoints[1:]:
        model = model_type.load_from_checkpoint(checkpoint)
        for weight, weight_to_add in zip(average_model.parameters(), model.parameters()):
            weight.data.add_(weight_to_add.data)

    # Divide by the number of checkpoints to obtain the mean.
    # Note: only parameters are averaged; buffers keep the first checkpoint's values.
    for weight in average_model.parameters():
        weight.data.divide_(n_models)

    return average_model
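
# Example usage (a sketch; `Img2TexTransformer` and the paths are hypothetical):
#
#     averaged = average_checkpoints(Img2TexTransformer, "checkpoints/")
#     torch.save(averaged.state_dict(), "averaged_model.pt")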