SunderAli17 committed on
Commit
d6dcd92
1 Parent(s): 7b883ad

Create hf_model.py

Files changed (1)
  1. eva_clip/hf_model.py +247 -0
eva_clip/hf_model.py ADDED
@@ -0,0 +1,247 @@
+""" huggingface model adapter
+Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
+"""
+
+import re
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch import TensorType
+
+try:
+    import transformers
+    from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
+    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
+        BaseModelOutputWithPoolingAndCrossAttentions
+except ImportError as e:
+    transformers = None
+
+    # minimal stand-ins so the annotations below still resolve when transformers is not installed
+    class BaseModelOutput:
+        pass
+
+    class PretrainedConfig:
+        pass
+
+from .hf_configs import arch_dict
+
+
+# utils
+def _camel2snake(s):
+    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
+
+
+# TODO: ?last - for gpt-like models
+_POOLERS = {}
+
+
+def register_pooler(cls):
+    """Decorator registering pooler class"""
+    _POOLERS[_camel2snake(cls.__name__)] = cls
+    return cls
+
+
+@register_pooler
+class MeanPooler(nn.Module):
+    """Mean pooling"""
+
+    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
+        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
+        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
+
+
+@register_pooler
+class MaxPooler(nn.Module):
+    """Max pooling"""
+
+    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
+        # fill padded positions with -inf so they never win the max
+        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
+        return masked_output.max(1).values
+
+
+@register_pooler
+class ClsPooler(nn.Module):
+    """CLS token pooling"""
+
+    def __init__(self, use_pooler_output=True):
+        super().__init__()
+        self.cls_token_position = 0
+        self.use_pooler_output = use_pooler_output
+
+    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
+        if (self.use_pooler_output and
+            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
+            (x.pooler_output is not None)
+        ):
+            return x.pooler_output
+
+        return x.last_hidden_state[:, self.cls_token_position, :]
+
+
+class HFTextEncoder(nn.Module):
+    """HuggingFace model adapter"""
+
+    def __init__(
+            self,
+            model_name_or_path: str,
+            output_dim: int,
+            tokenizer_name: str = None,
+            config: PretrainedConfig = None,
+            pooler_type: str = None,
+            proj: str = None,
+            pretrained: bool = True,
+            masked_language_modeling: bool = False):
+        super().__init__()
+
+        self.output_dim = output_dim
+
+        # TODO: find better way to get this information
+        uses_transformer_pooler = (pooler_type == "cls_pooler")
+
+        if transformers is None:
+            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
+        if config is None:
+            self.config = AutoConfig.from_pretrained(model_name_or_path)
+            if masked_language_modeling:
+                create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
+                    AutoModelForMaskedLM.from_config, self.config)
+            else:
+                create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
+                    AutoModel.from_config, self.config)
+            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
+            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
+                self.transformer = create_func(model_args)
+                self.transformer = self.transformer.encoder
+            else:
+                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
+        else:
+            self.config = config
+            if masked_language_modeling:
+                self.transformer = AutoModelForMaskedLM.from_config(config)
+            else:
+                self.transformer = AutoModel.from_config(config)
+
+        if pooler_type is None:  # get default arch pooler
+            self.pooler = _POOLERS[arch_dict[self.config.model_type]["pooler"]]()
+        else:
+            self.pooler = _POOLERS[pooler_type]()
+
+        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
+        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
+            self.proj = nn.Identity()
+        elif proj == 'linear':
+            self.proj = nn.Linear(d_model, output_dim, bias=False)
+        elif proj == 'mlp':
+            hidden_size = (d_model + output_dim) // 2
+            self.proj = nn.Sequential(
+                nn.Linear(d_model, hidden_size, bias=False),
+                nn.GELU(),
+                nn.Linear(hidden_size, output_dim, bias=False),
+            )
+
+        # self.itm_proj = nn.Linear(d_model, 2, bias=False)
+        # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size, bias=False)
+
+        # default to the model's own tokenizer when no separate tokenizer name is given
+        tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name_or_path
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+    # def forward_itm(self, x: TensorType, image_embeds: TensorType) -> TensorType:
+    #     image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(x.device)
+    #     attn_mask = (x != self.config.pad_token_id).long()
+    #     out = self.transformer(
+    #         input_ids=x,
+    #         attention_mask=attn_mask,
+    #         encoder_hidden_states=image_embeds,
+    #         encoder_attention_mask=image_atts,
+    #     )
+    #     pooled_out = self.pooler(out, attn_mask)
+
+    #     return self.itm_proj(pooled_out)
+
+    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
+        if masked_indices is None:
+            masked_indices = torch.bernoulli(probability_matrix).bool()
+
+        # never mask padding or the [CLS] position
+        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
+        masked_indices[input_ids == self.tokenizer.cls_token_id] = False
+
+        if targets is not None:
+            targets[~masked_indices] = -100  # We only compute loss on masked tokens
+
+        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8, device=device)).bool() & masked_indices
+        input_ids[indices_replaced] = self.tokenizer.mask_token_id
+
+        # 10% of the time, we replace masked input tokens with random word
+        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
+        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long, device=device)
+        input_ids[indices_random] = random_words[indices_random]
+        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+
+        if targets is not None:
+            return input_ids, targets
+        else:
+            return input_ids
+
+    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
+        labels = input_ids.clone()
+        attn_mask = (input_ids != self.config.pad_token_id).long()
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
+        vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
+        probability_matrix = torch.full(labels.shape, mlm_probability, device=input_ids.device)
+        input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
+                                      probability_matrix=probability_matrix)
+        mlm_output = self.transformer(input_ids,
+                                      attention_mask=attn_mask,
+                                      encoder_hidden_states=image_embeds,
+                                      encoder_attention_mask=image_atts,
+                                      return_dict=True,
+                                      labels=labels,
+                                      )
+        return mlm_output.loss
+        # mlm_output = self.transformer(input_ids,
+        #                               attention_mask=attn_mask,
+        #                               encoder_hidden_states=image_embeds,
+        #                               encoder_attention_mask=image_atts,
+        #                               return_dict=True,
+        #                               ).last_hidden_state
+        # logits = self.mlm_proj(mlm_output)
+
+        # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
+        # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
+        # labels = labels[:, 1:].contiguous().view(-1)
+
+        # mlm_loss = F.cross_entropy(
+        #     logits,
+        #     labels,
+        #     # label_smoothing=0.1,
+        # )
+        # return mlm_loss
+
+    def forward(self, x: TensorType) -> TensorType:
+        attn_mask = (x != self.config.pad_token_id).long()
+        out = self.transformer(input_ids=x, attention_mask=attn_mask)
+        pooled_out = self.pooler(out, attn_mask)
+
+        return self.proj(pooled_out)
+
+    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+        if not unlocked_layers:  # full freezing
+            for n, p in self.transformer.named_parameters():
+                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+            return
+
+        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
+        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
+        embeddings = getattr(
+            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
+        modules = [embeddings, *layer_list][:-unlocked_layers]
+        # freeze layers
+        for module in modules:
+            for n, p in module.named_parameters():
+                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.transformer.gradient_checkpointing_enable()
+
+    def get_num_layers(self):
+        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
+        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+        return len(layer_list)
+
+    def init_parameters(self):
+        pass
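
For context, a minimal sketch (not part of the diff above) of how this adapter might be driven on its own. It assumes the package layout used here (`eva_clip.hf_model`, with `arch_dict` in `eva_clip/hf_configs.py` providing an entry for the chosen model type) and uses "roberta-base" and a 512-dim projection purely as illustrative choices; within EVA-CLIP the class is normally constructed from the model configs rather than by hand.

    import torch
    from eva_clip.hf_model import HFTextEncoder

    # build a text tower with mean pooling and an MLP projection to a 512-d embedding space
    encoder = HFTextEncoder(
        model_name_or_path="roberta-base",   # illustrative HF checkpoint
        output_dim=512,
        tokenizer_name="roberta-base",
        pooler_type="mean_pooler",           # snake_case name produced by @register_pooler
        proj="mlp",
        pretrained=True,
    )
    encoder.lock(unlocked_layers=2)          # freeze all but the top two transformer layers

    tokens = encoder.tokenizer(
        ["a photo of a cat", "a diagram of a transformer"],
        padding="max_length", truncation=True, max_length=77, return_tensors="pt",
    )
    with torch.no_grad():
        text_features = encoder(tokens["input_ids"])   # shape: [2, 512]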