qilowoq committed on
Commit
ce06371
1 Parent(s): 8a41c0a

Delete AbLang_bert_model.py

Browse files
Files changed (1) hide show
  1. AbLang_bert_model.py +0 -34
AbLang_bert_model.py DELETED
@@ -1,34 +0,0 @@
1
- from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertEmbeddings, BertForMaskedLM, MaskedLMOutput
2
- from transformers import BertModel
3
- from typing import List, Optional, Tuple, Union
4
- import torch
5
-
6
class BertEmbeddingsV2(BertEmbeddings):
    """BERT embeddings built from word + position embeddings only.

    Positions are derived directly from ``input_ids`` so that every padding
    token maps to position 0, which is also the ``padding_idx`` of the
    position table (its embedding row is zero).
    """

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        # Rebuild the position table with padding_idx pinned to 0 so pad
        # positions contribute a zero vector regardless of config.
        self.position_embeddings = torch.nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=0,  # here padding_idx is always 0
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # NOTE(review): token_type_ids, position_ids, inputs_embeds and
        # past_key_values_length are accepted only for signature
        # compatibility with BertEmbeddings — they are ignored here;
        # presumably intentional for this model family, confirm upstream.
        word_emb = self.word_embeddings(input_ids)
        pos_ids = self.create_position_ids_from_input_ids(input_ids)
        pos_emb = self.position_embeddings(pos_ids)
        return self.dropout(self.LayerNorm(word_emb + pos_emb))

    def create_position_ids_from_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return 1-based positions for non-pad tokens; pad tokens get 0."""
        not_pad = (input_ids != self.pad_token_id).int()
        # cumsum over the sequence axis numbers real tokens 1..k; the final
        # multiply zeroes out the pad slots again.
        return (not_pad.cumsum(dim=1) * not_pad).long()
29
-
30
-
31
class BertModelV2(BertModel):
    """BertModel variant whose embedding layer is ``BertEmbeddingsV2``.

    Everything else (encoder, pooler, weight init) is inherited unchanged;
    only the embeddings module built by ``BertModel.__init__`` is replaced.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap the stock embeddings for the pad-aware-position variant.
        self.embeddings = BertEmbeddingsV2(config)