from transformers import CLIPModel
import torch
from typing import Optional, Tuple


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clip_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(logits_per_text)
    image_loss = contrastive_loss(logits_per_text.T)
    return (caption_loss + image_loss) / 2.0
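
# Illustrative note (not part of the original module): clip_loss is the symmetric
# contrastive objective used by CLIP. For a batch of N matched text-image pairs,
# row i of logits_per_text should peak at column i, so for example:
#   clip_loss(torch.eye(4) * 10.0)   # ~0, perfectly matched pairs
#   clip_loss(torch.zeros(4, 4))     # log(4) ~= 1.386, uninformative similarities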


class ClipMDModel(CLIPModel):
    """A CLIPModel whose text encoder accepts inputs longer than CLIP's 77-token
    context window by splitting them with a sliding window and mean-pooling the
    per-chunk embeddings."""

    def embed_text(self,
                   input_ids: torch.LongTensor,
                   attention_mask: torch.LongTensor,
                   output_attentions: Optional[bool] = None,
                   output_hidden_states: Optional[bool] = None,
                   position_ids: Optional[torch.LongTensor] = None,
                   ):
        """
        :param input_ids: tokenized text from CLIPProcessor.
        :param attention_mask: attention mask from CLIPProcessor.
        :return: text embeddings of input_ids (texts longer than 77 tokens
        are embedded using a sliding window and mean pooling).
        """
        tokens = []
        masks = []
        pos = []  # [start, end) indices of each text's chunks in the flattened batch

        for i in range(input_ids.size(0)):
            ten = input_ids[i]
            mask = attention_mask[i]
            # strip padding: keep only the attended tokens
            mask = mask[mask.nonzero().flatten()]
            ten = ten[:mask.size(0)]

            if not pos:
                pos.append([0, 0])
            else:
                pos.append([pos[-1][1], pos[-1][1]])

            # split the tokenized text into input-sized chunks with an overlapping window
            if ten.size(0) > 77:
                tokens.append(ten.unfold(dimension=0, size=77, step=70))
                masks.append(mask.unfold(dimension=0, size=77, step=70))

                pos[-1][1] += tokens[-1].size(0)

                # keep the tail that the sliding windows did not fully cover
                ten = ten[tokens[-1].size(0) * 70:]
                mask = mask[tokens[-1].size(0) * 70:]

            if ten.size(0) > 0:
                # pad the (remaining) chunk to the 77-token context length
                new_mask = torch.zeros((1, 77)).to(self.device)
                new_mask[:, 0:mask.size(0)] = mask

                # 49407 is CLIP's end-of-text token id, used here as padding
                new_ten = torch.full((1, 77), 49407).to(self.device)
                new_ten[:, 0:ten.size(0)] = ten

                tokens.append(new_ten)
                masks.append(new_mask)
                pos[-1][1] += 1
        # encode all chunks in a single batch
        embedded = self.get_text_features(input_ids=torch.cat(tokens, 0),
                                          attention_mask=torch.cat(masks, 0),
                                          output_attentions=output_attentions,
                                          output_hidden_states=output_hidden_states,
                                          position_ids=position_ids,
                                          )

        # mean-pool the embeddings of chunks that came from the same original text
        embeddings = []
        for p in pos:
            if p[1] - p[0] == 1:
                embeddings.append(embedded[p[0]].unsqueeze(0))
            else:
                embeddings.append(torch.mean(embedded[p[0]:p[1]], dim=0).unsqueeze(0))

        return torch.cat(embeddings, 0)

    def forward(self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple:
        """
        :param input_ids: tokenized text from CLIPProcessor.
        :param attention_mask: attention mask from CLIPProcessor.
        :param pixel_values: pixel values from CLIPProcessor.
        :param return_loss: boolean that indicates if loss should be returned
        :return: image-caption cosine similarity as logits per image and per caption (also loss if return_loss is true)
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = self.config.use_return_dict
        
        # encode the images
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        image_embeds = vision_outputs[1]  # pooled output
        image_embeds = self.visual_projection(image_embeds)

        # encode the text captions (handles captions longer than 77 tokens)
        text_embeds = self.embed_text(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      output_attentions=output_attentions,
                                      output_hidden_states=output_hidden_states,
                                      position_ids=position_ids,
                                      )


        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        if return_loss:
            loss = clip_loss(logits_per_text)
            return logits_per_image, logits_per_text, loss
        return logits_per_image, logits_per_text
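

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how the model
# is driven end to end with a CLIPProcessor; the checkpoint name and the captions
# below are illustrative assumptions, and any checkpoint loadable by
# CLIPModel.from_pretrained should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPProcessor

    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = ClipMDModel.from_pretrained("openai/clip-vit-base-patch32")

    # two dummy images and two captions; the second caption deliberately exceeds
    # 77 tokens so embed_text splits it with a sliding window and mean-pools the chunks
    images = [Image.new("RGB", (224, 224), color=c) for c in ("red", "blue")]
    captions = [
        "a chest x-ray with no acute findings",
        "a chest x-ray report " + "with a long free-text impression section " * 10,
    ]

    # note: do not truncate, otherwise the long caption would be cut at 77 tokens
    inputs = processor(text=captions, images=images, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits_per_image, logits_per_text, loss = model(**inputs, return_loss=True)

    print(logits_per_image.shape)  # (2, 2) image-caption similarity matrix
    print(loss.item())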