import torch
import torch.nn as nn

from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

class LNFinalyMixin(BaseMixin):
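    """Mixin that replaces the final output head with a LayerNorm (ln_vision)
    applied to the transformer's output hidden states."""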
    def __init__(self, hidden_size):
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        return self.ln_vision(logits)


class EVAViT(ViTModel):
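    """EVA ViT vision encoder: the default 'cls' head is swapped for LNFinalyMixin,
    so forward(image) returns LayerNorm-ed visual features."""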
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        batch_size = image.size(0)
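        # The base ViTModel interface expects token ids and an attention mask, so pass
        # a single dummy id per sample and an all-ones mask that broadcasts; the image
        # itself is the real input.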
        input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
        attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)


class QFormer(BaseModel):
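    """Querying transformer: position embeddings are disabled and the output head is
    the identity; its query tokens cross-attend to the vision encoder_outputs."""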
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output,
                         activation_func=nn.functional.gelu, **kwargs)
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        return None

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
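        # Token ids 0..31 select the 32 learned query embeddings; both the self- and
        # cross-attention masks are all-ones scalars that broadcast.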
        input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device)
        input_ids = input_ids.unsqueeze(0).expand(batch_size, -1)
        attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask,
                               encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)


class BLIP2(torch.nn.Module):
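    """BLIP-2 style visual front-end: EVA ViT features are summarized into 32 Q-Former
    queries and projected (768 -> 4096) into the language model's hidden space."""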
    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        if vit is not None:
            self.vit = vit
        else:
            self.vit = EVAViT(EVAViT.get_args(**eva_args))
        if qformer is not None:
            self.qformer = qformer
        else:
            self.qformer = QFormer(QFormer.get_args(**qformer_args))

        # Project Q-Former outputs (hidden size 768) to the language model's hidden
        # size (4096), keeping the same device and dtype as the Q-Former parameters.
        qformer_param = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(device=qformer_param.device, dtype=qformer_param.dtype)

    def forward(self, image, **kwargs):
        enc = self.vit(image)[0]    # LayerNorm-ed visual features from the EVA ViT
        out = self.qformer(enc)[0]  # 32 query embeddings attending to those features
        return self.glm_proj(out)   # projected into the language model's hidden size


class BlipImageBaseProcessor:
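    """Shared base: builds the Normalize transform from the standard CLIP mean/std
    when no explicit statistics are given."""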
    def __init__(self, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


class BlipImageEvalProcessor(BlipImageBaseProcessor):
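    """Eval-time preprocessing: bicubic resize to (image_size, image_size), ToTensor,
    then normalization."""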
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
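
# Illustrative usage sketch (not part of the original module); the argument dicts,
# image path, and image_size below are hypothetical placeholders:
#
#   from PIL import Image
#   processor = BlipImageEvalProcessor(image_size=224)
#   pixels = processor(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
#   model = BLIP2(eva_args=eva_args, qformer_args=qformer_args)  # args from the model config
#   lm_inputs = model(pixels)  # [batch, 32, 4096] query features for the language model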