Bingsu committed
Commit bfd9473
1 Parent(s): 8a88047

Update modeling_clip_masked_lm.py

Files changed (1)
  1. modeling_clip_masked_lm.py +16 -11
modeling_clip_masked_lm.py CHANGED
@@ -2,33 +2,38 @@ from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers import CLIPTextConfig, CLIPTextModel
+from transformers import CLIPTextConfig
 from transformers.modeling_outputs import MaskedLMOutput
-from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
+from transformers.models.clip.modeling_clip import (
+    CLIPPreTrainedModel,
+    CLIPTextTransformer,
+)
 from transformers.models.roberta.modeling_roberta import RobertaLMHead
 
 
 class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
     config_class = CLIPTextConfig
 
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPTextConfig):
         super().__init__(config)
-        self.clip_text_model = CLIPTextModel(config)
+        self.text_model = CLIPTextTransformer(config)
         self.lm_head = RobertaLMHead(config)
 
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.clip_text_model.text_model.embeddings.token_embedding
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
 
-    def set_input_embeddings(self, value):
-        self.clip_text_model.text_model.embeddings.token_embedding = value
+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.text_model.embeddings.token_embedding = value
 
-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Module:
         return self.lm_head.decoder
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head.decoder = new_embeddings
+    def set_output_embeddings(self, value: nn.Module) -> None:
+        self.lm_head.decoder = value
 
     def forward(
         self,
@@ -44,7 +49,7 @@ class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
-        outputs = self.clip_text_model(
+        outputs = self.text_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
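
For context, a minimal usage sketch of the refactored class. Assumptions not in the diff: modeling_clip_masked_lm.py is importable from the working directory, "openai/clip-vit-base-patch32" is only a stand-in checkpoint used to size the (randomly initialized) model, and forward returns a MaskedLMOutput with logits, as the imports suggest.

import torch
from transformers import CLIPTextConfig, CLIPTokenizer

from modeling_clip_masked_lm import CLIPTextModelForMaskedLM

# Stand-in checkpoint: any CLIP text config works here; CLIPTextConfig
# extracts the text_config from a composite CLIP checkpoint automatically.
config = CLIPTextConfig.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

model = CLIPTextModelForMaskedLM(config)  # weights are randomly initialized
model.eval()

inputs = tokenizer("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# lm_head (RobertaLMHead) projects hidden states back onto the vocabulary;
# assuming a MaskedLMOutput return, logits is (batch, seq_len, vocab_size).
print(outputs.logits.shape)

Wrapping CLIPTextTransformer directly, rather than a nested CLIPTextModel, flattens the attribute path (self.text_model.embeddings instead of self.clip_text_model.text_model.embeddings), which is what makes the shorter get/set embedding accessors above work.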