qilowoq commited on
Commit
4927150
1 Parent(s): 008d612

Update AbLang_bert_model.py

Browse files
Files changed (1) hide show
  1. AbLang_bert_model.py +2 -80
AbLang_bert_model.py CHANGED
@@ -26,87 +26,9 @@ class BertEmbeddingsV2(BertEmbeddings):
26
  def create_position_ids_from_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
27
  mask = input_ids.ne(self.pad_token_id).int()
28
  return torch.cumsum(mask, dim=1).long() * mask
29
-
30
-
31
- class BertModelV2(BertModel):
32
- def __init__(self, config):
33
- super().__init__(config)
34
- self.embeddings = BertEmbeddingsV2(config)
35
 
36
 
37
- class BertForMaskedLMV2(BertForMaskedLM):
38
  def __init__(self, config):
39
  super().__init__(config)
40
-
41
- def forward(
42
- self,
43
- input_ids: Optional[torch.Tensor] = None,
44
- attention_mask: Optional[torch.Tensor] = None,
45
- token_type_ids: Optional[torch.Tensor] = None,
46
- position_ids: Optional[torch.Tensor] = None,
47
- head_mask: Optional[torch.Tensor] = None,
48
- inputs_embeds: Optional[torch.Tensor] = None,
49
- encoder_hidden_states: Optional[torch.Tensor] = None,
50
- encoder_attention_mask: Optional[torch.Tensor] = None,
51
- labels: Optional[torch.Tensor] = None,
52
- output_attentions: Optional[bool] = None,
53
- output_hidden_states: Optional[bool] = None,
54
- return_dict: Optional[bool] = None,
55
- ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
56
- r"""
57
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
58
- Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
59
- config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
60
- loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
61
- """
62
-
63
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
64
-
65
- outputs = self.bert(
66
- input_ids,
67
- attention_mask=attention_mask,
68
- token_type_ids=token_type_ids,
69
- position_ids=position_ids,
70
- head_mask=head_mask,
71
- inputs_embeds=inputs_embeds,
72
- encoder_hidden_states=encoder_hidden_states,
73
- encoder_attention_mask=encoder_attention_mask,
74
- output_attentions=output_attentions,
75
- output_hidden_states=output_hidden_states,
76
- return_dict=return_dict,
77
- )
78
-
79
- sequence_output = outputs[0]
80
- prediction_scores = sequence_output[:, :, 0:24]
81
-
82
- masked_lm_loss = None
83
- if labels is not None:
84
- loss_fct = torch.nn.CrossEntropyLoss() # -100 index = padding token
85
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
86
-
87
- if not return_dict:
88
- output = (prediction_scores,) + outputs[2:]
89
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
90
-
91
- return MaskedLMOutput(
92
- loss=masked_lm_loss,
93
- logits=prediction_scores,
94
- hidden_states=outputs.hidden_states,
95
- attentions=outputs.attentions,
96
- )
97
-
98
- def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
99
- input_shape = input_ids.shape
100
- effective_batch_size = input_shape[0]
101
-
102
- # add a dummy token
103
- if self.config.pad_token_id is None:
104
- raise ValueError("The PAD token should be defined for generation")
105
-
106
- attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
107
- dummy_token = torch.full(
108
- (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
109
- )
110
- input_ids = torch.cat([input_ids, dummy_token], dim=1)
111
-
112
- return {"input_ids": input_ids, "attention_mask": attention_mask}
 
26
  def create_position_ids_from_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
27
  mask = input_ids.ne(self.pad_token_id).int()
28
  return torch.cumsum(mask, dim=1).long() * mask
 
 
 
 
 
 
29
 
30
 
31
+ class BertModelV2(BertModel):
32
  def __init__(self, config):
33
  super().__init__(config)
34
+ self.embeddings = BertEmbeddingsV2(config)