qilowoq committed
Commit bac8bc3
1 Parent(s): 03f96d6

Upload model

Files changed (3):
  1. AbLang_bert_model.py +112 -0
  2. config.json +9 -7
  3. pytorch_model.bin +2 -2
AbLang_bert_model.py ADDED
@@ -0,0 +1,112 @@
+from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertEmbeddings, BertForMaskedLM, MaskedLMOutput
+from transformers import BertModel
+from typing import List, Optional, Tuple, Union
+import torch
+
+
+class BertEmbeddingsV2(BertEmbeddings):
+    def __init__(self, config):
+        super().__init__(config)
+        self.pad_token_id = config.pad_token_id
+        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)  # here padding_idx is always 0
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
+        inputs_embeds = self.word_embeddings(input_ids)
+        position_ids = self.create_position_ids_from_input_ids(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+        return self.dropout(self.LayerNorm(embeddings))
+
+    def create_position_ids_from_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
+        mask = input_ids.ne(self.pad_token_id).int()
+        return torch.cumsum(mask, dim=1).long() * mask
+
+
+class BertModelV2(BertModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.embeddings = BertEmbeddingsV2(config)
+
+
+class BertForMaskedLMV2(BertForMaskedLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = sequence_output[:, :, 0:24]
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = torch.nn.CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        # add a dummy token
+        if self.config.pad_token_id is None:
+            raise ValueError("The PAD token should be defined for generation")
+
+        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+        dummy_token = torch.full(
+            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+        )
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
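
For reference, a minimal sketch (not part of the commit) of how the padding-aware position ids computed by BertEmbeddingsV2.create_position_ids_from_input_ids behave. The value pad_token_id = 21 is taken from the updated config.json below; the example token ids are made up.

import torch

pad_token_id = 21  # from config.json; example input ids below are arbitrary
input_ids = torch.tensor([
    [5, 7, 9, 21, 21],   # 3 real tokens, padded to length 5
    [3, 4, 6, 8, 10],    # 5 real tokens, no padding
])

# Same computation as create_position_ids_from_input_ids:
mask = input_ids.ne(pad_token_id).int()
position_ids = torch.cumsum(mask, dim=1).long() * mask
print(position_ids)
# tensor([[1, 2, 3, 0, 0],
#         [1, 2, 3, 4, 5]])
# Real tokens get 1-based positions; padding positions collapse to index 0,
# matching padding_idx=0 of the position embedding table in BertEmbeddingsV2.__init__.
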
config.json CHANGED
@@ -1,14 +1,13 @@
 {
-  "_name_or_path": "ablang-test",
+  "add_pooling_layer": false,
   "architectures": [
-    "AbLang"
+    "BertModelV2"
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "config.AbLangConfig",
-    "AutoModel": "model.AbLang"
+    "AutoModel": "AbLang_bert_model.BertModelV2"
   },
-  "chain": "heavy",
+  "classifier_dropout": null,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
@@ -19,8 +18,11 @@
   "model_type": "bert",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
-  "ptid": 21,
+  "pad_token_id": 21,
+  "position_embedding_type": "absolute",
   "torch_dtype": "float32",
-  "transformers_version": "4.26.1",
+  "transformers_version": "4.28.1",
+  "type_vocab_size": 2,
+  "use_cache": true,
   "vocab_size": 24
 }
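
The new auto_map entry points AutoModel at the custom class added in this commit, so loading requires trust_remote_code=True. A minimal loading sketch follows; the repo id "qilowoq/AbLang_heavy" is an assumption, since the repository name is not shown in this diff.

from transformers import AutoModel

# Hypothetical repo id -- replace with the actual repository this commit belongs to.
repo_id = "qilowoq/AbLang_heavy"

# trust_remote_code=True lets AutoModel resolve "AbLang_bert_model.BertModelV2"
# from the auto_map entry in config.json.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected: BertModelV2
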
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c526e50a72c294a36b8391ad52936a28832194061026cb1fe9e0455b07e7b5b6
-size 340855773
+oid sha256:c9d5458446b8f723995df81e9b24b7a4635285fcb33d0d787a7e308bb16c75ea
+size 343223341