File size: 1,741 Bytes
6497501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image


class ParseqPredictor(nn.Module):
    """Wrap a PARSeq text recognizer for inference and as a recognition loss.

    The recognizer is loaded from the local ``./src/parseq`` hub directory.
    Input images are resized to the recognizer's ``hparams.img_size`` (bicubic,
    antialiased) and normalized with mean/std 0.5 before being fed to it.
    """

    def __init__(self, ckpt_path=None, freeze=True, *args, **kwargs):
        """Build the recognizer, optionally load weights, optionally freeze it.

        Args:
            ckpt_path: Path to a state dict for the PARSeq model. When ``None``
                no checkpoint is loaded and the weights are whatever
                ``torch.hub.load`` produced (or whatever a caller loads later).
            freeze: If true, disable gradients for every PARSeq parameter.
            *args, **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.parseq = torch.hub.load('./src/parseq', 'parseq', source='local').eval()
        # torch.load(None) always fails, so only load when a path is given.
        if ckpt_path is not None:
            self.parseq.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        self.parseq_transform = transforms.Compose([
            transforms.Resize(self.parseq.hparams.img_size, transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.Normalize(0.5, 0.5)
        ])

        if freeze:
            self.freeze()

    def freeze(self):
        """Turn off gradients for all PARSeq parameters (in place)."""
        for param in self.parseq.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run the recognizer on a batch of images and return its output.

        Args:
            x: Iterable of image tensors (a batched tensor or a list). Each
                element is transformed on its own, so elements may differ in
                spatial size. Presumably (C, H, W) in [0, 1] given the 0.5/0.5
                normalization -- TODO confirm with callers.

        Returns:
            Whatever ``self.parseq`` returns for the stacked batch;
            ``calc_loss`` treats it as (B, steps, num_classes) logits.
        """
        # Resize/normalize each image separately so inputs of different sizes
        # can be stacked, then move the batch to the recognizer's device.
        batch = torch.cat([self.parseq_transform(t[None]) for t in x])
        device = next(self.parameters()).device
        return self.parseq(batch.to(device))

    def img2txt(self, x):
        """Decode the recognizer's predictions for ``x`` into strings.

        Returns the labels produced by ``tokenizer.decode``; the confidences it
        also returns are discarded.
        """
        labels, _confidences = self.parseq.tokenizer.decode(self(x))
        return labels

    def calc_loss(self, x, label):
        """Per-image recognition loss of ``x`` against target strings.

        For each image, the cross-entropy between the predicted character
        logits and the ground-truth character ids is computed. The EOS step is
        not supervised. The result is clamped to at most 1.0.

        Args:
            x: Images, as accepted by ``forward``.
            label: Sequence of target strings, one per image.

        Returns:
            Tensor of shape (B,) with one clamped loss per image. An image
            whose target has no characters, or for which the recognizer
            produced no steps, contributes 0.0 instead of NaN.
        """
        preds = self(x)  # assumed (B, steps, C); typically steps=26, C=95
        gt_ids = self.parseq.tokenizer.encode(label).to(preds.device)  # (B, L)

        losses = []
        for pred, gt_id in zip(preds, gt_ids):
            # Assumes id 0 is EOS and each row starts with a single BOS token.
            # This mirrors the original slicing. Use the first EOS position.
            eos_pos = (gt_id == 0).nonzero()[0, 0].item()
            target = gt_id[1:eos_pos]

            # The recognizer can return fewer steps than the label has
            # characters, e.g. for labels longer than its max length.
            # NOTE(review): possibly also when decoding stops early; confirm
            # against the PARSeq decoder. Score only the overlapping prefix
            # instead of failing on a shape mismatch.
            n = min(target.shape[0], pred.shape[0])
            if n == 0:
                # cross_entropy over zero elements is NaN. Use a zero that
                # stays attached to the graph so backward() keeps working.
                losses.append(pred[:0].sum().reshape(1))
                continue

            ce_loss = nn.functional.cross_entropy(pred[:n], target[:n])
            losses.append(torch.clamp(ce_loss, max=1.0).reshape(1))

        return torch.cat(losses)