calpt committed on
Commit 9ca9564
1 Parent(s): e002a0e

Upload 2 files

Files changed (2)
  1. config.json +6 -3
  2. modeling_clip.py +129 -0
config.json CHANGED
@@ -4,6 +4,9 @@
   "architectures": [
     "OpenCLIPVisionTextDualEncoderModel"
   ],
+  "auto_map": {
+    "AutoModel": "modeling_clip.OpenCLIPVisionTextDualEncoderModel"
+  },
   "logit_scale_init_value": 2.6592,
   "model_type": "vision-text-dual-encoder",
   "projection_dim": 1024,
@@ -82,7 +85,7 @@
     "top_p": 1.0,
     "torch_dtype": null,
     "torchscript": false,
-    "transformers_version": "4.27.0.dev0",
+    "transformers_version": "4.24.0",
     "type_vocab_size": 1,
     "typical_p": 1.0,
     "use_bfloat16": false,
@@ -104,6 +107,7 @@
     "decoder_start_token_id": null,
     "diversity_penalty": 0.0,
     "do_sample": false,
+    "dropout": 0.0,
     "early_stopping": false,
     "encoder_no_repeat_ngram_size": 0,
     "eos_token_id": null,
@@ -146,7 +150,6 @@
     "patch_size": 14,
     "prefix": null,
     "problem_type": null,
-    "projection_dim": 512,
     "pruned_heads": {},
     "remove_invalid_values": false,
     "repetition_penalty": 1.0,
@@ -164,7 +167,7 @@
     "top_p": 1.0,
     "torch_dtype": null,
     "torchscript": false,
-    "transformers_version": "4.27.0.dev0",
+    "transformers_version": "4.24.0",
     "typical_p": 1.0,
     "use_bfloat16": false
   }
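
The new auto_map entry is what lets the Auto classes resolve to the custom class in modeling_clip.py when the checkpoint is loaded with remote code enabled. A minimal loading sketch, assuming a placeholder repository id (the actual repository name is not part of this diff):

from transformers import AutoModel

# Hypothetical repository id; the auto_map added above routes AutoModel to
# modeling_clip.OpenCLIPVisionTextDualEncoderModel.
model = AutoModel.from_pretrained(
    "user/openclip-vision-text-dual-encoder",  # placeholder repo id
    trust_remote_code=True,  # required to execute the custom modeling_clip.py
)
print(type(model).__name__)  # OpenCLIPVisionTextDualEncoderModel
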
modeling_clip.py ADDED
@@ -0,0 +1,129 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
+ from transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder import clip_loss, CLIPOutput
+
+
+ class MeanPooler(nn.Module):
+     """Mean pooling"""
+
+     def forward(self, x, attention_mask):
+         masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
+         return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
+
+
+ class OpenCLIPVisionTextDualEncoderModel(VisionTextDualEncoderModel):
+     def __init__(
+         self,
+         config: Optional[VisionTextDualEncoderConfig] = None,
+         vision_model: Optional[PreTrainedModel] = None,
+         text_model: Optional[PreTrainedModel] = None,
+         add_text_model_pooling_layer: bool = False,
+     ):
+         super().__init__(config, vision_model, text_model)
+
+         # Remove text pooling layer
+         if not add_text_model_pooling_layer:
+             self.text_model.pooler = None
+
+         # Add mean pooling
+         self.pooler = MeanPooler()
+         # Overwrite text projection
+         hidden_size = (self.text_embed_dim + self.projection_dim) // 2
+         self.text_projection = nn.Sequential(
+             nn.Linear(self.text_embed_dim, hidden_size, bias=False),
+             nn.GELU(),
+             nn.Linear(hidden_size, self.projection_dim, bias=False),
+         )
+
+     def get_text_features(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = self.pooler(text_outputs, attention_mask)
+         text_features = self.text_projection(pooled_output)
+
+         return text_features
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         return_loss: Optional[bool] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         image_embeds = vision_outputs[1]  # pooler_output
+         image_embeds = self.visual_projection(image_embeds)
+
+         pooled_output = self.pooler(text_outputs, attention_mask)
+         text_embeds = self.text_projection(pooled_output)
+
+         # normalized features
+         image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+         text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+         # cosine similarity as logits
+         logit_scale = self.logit_scale.exp()
+         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+         logits_per_image = logits_per_text.T
+
+         loss = None
+         if return_loss:
+             loss = clip_loss(logits_per_text)
+
+         if not return_dict:
+             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+             return ((loss,) + output) if loss is not None else output
+
+         return CLIPOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
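
For context, a minimal usage sketch for the model defined above. The repository id is a placeholder, and it assumes the checkpoint ships a matching tokenizer; text embeddings come from get_text_features (mean pooling plus the two-layer projection), while forward combines both towers into CLIP-style logits and, with return_loss=True, the contrastive loss:

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "user/openclip-vision-text-dual-encoder"  # hypothetical repo id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

inputs = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Mean-pooled, MLP-projected text embeddings (see get_text_features above).
    text_embeds = model.get_text_features(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
print(text_embeds.shape)  # (2, 1024) given projection_dim=1024 in config.json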