marcusinthesky committed
Commit 924daa7
1 parent: 691e117

Upload model

Files changed (3)
  1. config.json +179 -0
  2. model.safetensors +3 -0
  3. modelling.py +174 -0
config.json ADDED
@@ -0,0 +1,179 @@
+ {
+   "_commit_hash": "e9754cb57664705bacb62145ea8a977f269a456b",
+   "_name_or_path": "flavour/vtde-dinov2-small-bge-small-en",
+   "architectures": [
+     "VTDEModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "modelling.VTDEConfig",
+     "AutoModel": "modelling.VTDEModel"
+   },
+   "logit_scale_init_value": 2.6592,
+   "model_type": "vtde",
+   "projection_dim": 384,
+   "text_config": {
+     "_name_or_path": "BAAI/bge-small-en",
+     "add_cross_attention": false,
+     "architectures": [
+       "BertModel"
+     ],
+     "attention_probs_dropout_prob": 0.1,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "classifier_dropout": null,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 384,
+     "id2label": {
+       "0": "LABEL_0"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 1536,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0
+     },
+     "layer_norm_eps": 1e-12,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 512,
+     "min_length": 0,
+     "model_type": "bert",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 12,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 0,
+     "position_embedding_type": "absolute",
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "float32",
+     "torchscript": false,
+     "transformers_version": "4.32.0.dev0",
+     "type_vocab_size": 2,
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "vocab_size": 30522
+   },
+   "text_pooling_mode": "mean",
+   "torch_dtype": "float32",
+   "transformers_version": null,
+   "vision_config": {
+     "_name_or_path": "facebook/dinov2-small",
+     "add_cross_attention": false,
+     "architectures": [
+       "Dinov2Model"
+     ],
+     "attention_probs_dropout_prob": 0.0,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "drop_path_rate": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.0,
+     "hidden_size": 384,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "image_size": 518,
+     "initializer_range": 0.02,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-06,
+     "layerscale_value": 1.0,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "mlp_ratio": 4,
+     "model_type": "dinov2",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 6,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_channels": 3,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 14,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "qkv_bias": true,
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "float32",
+     "torchscript": false,
+     "transformers_version": "4.32.0.dev0",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_swiglu_ffn": false
+   },
+   "vision_pooling_mode": "max"
+ }
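
Note: the configuration above wires a BAAI/bge-small-en text encoder to a facebook/dinov2-small vision encoder, with mean pooling on the text side, max pooling on the vision side, and a 384-dimensional shared projection. As a minimal, illustrative sketch (not part of this commit), an equivalent config could be assembled from the two backbone configs via the from_vision_text_configs helper inherited from VisionTextDualEncoderConfig, assuming modelling.py from this commit is importable:

# Illustrative sketch only: rebuild a config equivalent to config.json above.
# Assumes modelling.py from this commit is on the Python path.
from transformers import AutoConfig
from modelling import VTDEConfig

text_config = AutoConfig.from_pretrained("BAAI/bge-small-en")
vision_config = AutoConfig.from_pretrained("facebook/dinov2-small")

config = VTDEConfig.from_vision_text_configs(
    vision_config=vision_config,
    text_config=text_config,
    projection_dim=384,
    text_pooling_mode="mean",
    vision_pooling_mode="max",
)

Saving a config of a class registered for auto classes (see the register_for_auto_class calls in modelling.py below) is what writes the auto_map entries shown above, so AutoConfig/AutoModel can resolve the custom classes when trust_remote_code=True.
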
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8c1a74eb8ad11a602187aec925f1a8efadeadd307c78ddb22afc032aa2cf508
+ size 223489116
modelling.py ADDED
@@ -0,0 +1,174 @@
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/12_modelling.ipynb.
+
+ # %% auto 0
+ __all__ = ['VTDEConfig', 'VTDEModel']
+
+ # %% ../notebooks/12_modelling.ipynb 1
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoModel, PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
+ from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss
+
+
+ class VTDEConfig(VisionTextDualEncoderConfig):
+     model_type = "vtde"
+
+     def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
+                  text_pooling_mode='mean',
+                  vision_pooling_mode='max',
+                  **kwargs):
+         """
+         pooling_mode in ['mean', 'max', 'cls', 'norm']
+         https://arxiv.org/pdf/2210.09996.pdf
+         https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
+         """
+         self.text_pooling_mode = text_pooling_mode
+         self.vision_pooling_mode = vision_pooling_mode
+         super().__init__(projection_dim=projection_dim, logit_scale_init_value=logit_scale_init_value, **kwargs)
+
+
+ class VTDEModel(VisionTextDualEncoderModel):
+     config_class = VTDEConfig
+     base_model_prefix = "vtde"
+
+     def __init__(
+         self,
+         config: Optional[VTDEConfig] = None,
+         vision_model: Optional[PreTrainedModel] = None,
+         text_model: Optional[PreTrainedModel] = None,
+     ):
+         # You can customize the constructor if needed
+         super().__init__(config, vision_model, text_model)
+         self.text_pooling_mode = config.text_pooling_mode
+         self.vision_pooling_mode = config.vision_pooling_mode
+
+     def get_text_features(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         if self.text_pooling_mode == 'cls':
+             pooled_output = text_outputs[1]
+         elif self.text_pooling_mode == 'mean':
+             pooled_output = torch.mean(text_outputs[0], dim=1)
+         elif self.text_pooling_mode == 'max':
+             pooled_output = torch.max(text_outputs[0], dim=1)[0]
+         elif self.text_pooling_mode == 'norm':
+             # select, per sample, the token with the largest norm (excluding the CLS token at index 0)
+             last_hidden_states = text_outputs[0]
+             token_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
+             max_norm_idx = torch.argmax(token_norms, dim=1) + 1  # offset for the skipped CLS token
+             batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
+             pooled_output = last_hidden_states[batch_idx, max_norm_idx, :]
+         else:
+             raise NotImplementedError(f"Unknown text_pooling_mode: {self.text_pooling_mode}")
+
+         text_features = self.text_projection(pooled_output)
+
+         return text_features
+
+     def get_image_features(
+         self,
+         pixel_values=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.vision_pooling_mode == 'cls':
+             pooled_output = vision_outputs[1]
+         elif self.vision_pooling_mode == 'mean':
+             pooled_output = torch.mean(vision_outputs[0], dim=1)
+         elif self.vision_pooling_mode == 'max':
+             pooled_output = torch.max(vision_outputs[0], dim=1)[0]
+         elif self.vision_pooling_mode == 'norm':
+             # select, per sample, the patch with the largest norm (excluding the CLS token at index 0)
+             last_hidden_states = vision_outputs[0]
+             patch_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
+             max_norm_idx = torch.argmax(patch_norms, dim=1) + 1  # offset for the skipped CLS token
+             batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
+             pooled_output = last_hidden_states[batch_idx, max_norm_idx, :]
+         else:
+             raise NotImplementedError(f"Unknown vision_pooling_mode: {self.vision_pooling_mode}")
+
+         image_features = self.visual_projection(pooled_output)
+
+         return image_features
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         return_loss: Optional[bool] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         image_embeds = self.get_image_features(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         text_embeds = self.get_text_features(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # normalized features
+         image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+         text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+         # cosine similarity as logits
+         logit_scale = self.logit_scale.exp()
+         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+         logits_per_image = logits_per_text.T
+
+         loss = None
+         if return_loss:
+             loss = clip_loss(logits_per_text)
+
+         if not return_dict:
+             # only the projected embeddings are returned; the raw encoder outputs are not retained here
+             output = (logits_per_image, logits_per_text, text_embeds, image_embeds)
+             return ((loss,) + output) if loss is not None else output
+
+         return CLIPOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_embeds,
+             vision_model_output=image_embeds,
+         )
+
+
+ VTDEConfig.register_for_auto_class()
+ VTDEModel.register_for_auto_class("AutoModel")
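
Because VTDEConfig and VTDEModel register themselves for auto classes and config.json's auto_map points at modelling.py, the uploaded checkpoint should be loadable through the Auto API with remote code enabled. A minimal usage sketch follows; it is untested here and assumes the hub repo id from _name_or_path above and that the two backbone checkpoints' tokenizer and image processor are the appropriate preprocessors:

# Usage sketch (assumptions: repo id matches this upload; the backbone
# checkpoints' preprocessors are suitable for the fine-tuned towers).
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

repo_id = "flavour/vtde-dinov2-small-bge-small-en"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)  # resolves to modelling.VTDEModel via auto_map
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en")
image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")

texts = ["a photo of a cat", "a photo of a dog"]
image = Image.new("RGB", (224, 224))  # placeholder image, for illustration only

with torch.no_grad():
    text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
    text_embeds = model.get_text_features(**text_inputs)
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    image_embeds = model.get_image_features(pixel_values=pixel_values)

# cosine similarity between L2-normalized projected embeddings
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
print(text_embeds @ image_embeds.T)

The placeholder image and prompt strings are illustrative; any PIL image and text batch can be substituted.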