Bingsu committed
Commit 194b80a
1 Parent(s): 50a331b

Upload modeling_clip_masked_lm.py

Files changed (1)
  1. modeling_clip_masked_lm.py +75 -0
modeling_clip_masked_lm.py ADDED
@@ -0,0 +1,75 @@
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPTextConfig, CLIPTextModel
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaLMHead


class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
    """CLIP text encoder with a RoBERTa-style LM head for masked language modeling."""

    config_class = CLIPTextConfig

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.clip_text_model = CLIPTextModel(config)
        # RobertaLMHead only reads hidden_size, layer_norm_eps and vocab_size,
        # all of which CLIPTextConfig provides.
        self.lm_head = RobertaLMHead(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.clip_text_model.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.clip_text_model.text_model.embeddings.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.clip_text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        mlm_loss = None
        if labels is not None:
            # Positions labelled -100 are ignored (CrossEntropyLoss's default ignore_index).
            loss_fct = nn.CrossEntropyLoss()
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            # Skip the pooled output; keep hidden_states / attentions if present.
            output = (prediction_scores,) + outputs[2:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )