This model was pretrained on the BookCorpus dataset using knowledge distillation.

What makes this model particular is that, although it shares BERT's architecture, it has a hidden size of 384 (half that of BERT-base) and 6 attention heads, so each head keeps the same size as in BERT (384 / 6 = 64).
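The head-size arithmetic can be checked with a few lines of plain Python. The helper names below are hypothetical, used only for illustration. They show that halving both the hidden size and the number of heads keeps every head 64-dimensional. They also show that the attention projection weights in each layer shrink by roughly a factor of four, because those weights scale with the square of the hidden size.

````python
def head_size(hidden_size: int, num_heads: int) -> int:
    """Hypothetical helper: each attention head works on an equal slice of the hidden vector."""
    assert hidden_size % num_heads == 0, "hidden size must be divisible by the number of heads"
    return hidden_size // num_heads


def attention_params(hidden_size: int) -> int:
    """Hypothetical helper: weights + biases of the Q, K, V and output projections of one layer."""
    return 4 * (hidden_size * hidden_size + hidden_size)


# BERT-base (as in bert-base-uncased): hidden size 768, 12 heads
print(head_size(768, 12))      # 64
# This model: hidden size 384, 6 heads
print(head_size(384, 6))       # 64

print(attention_params(768))   # 2362368
print(attention_params(384))   # 591360
````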

The knowledge distillation was performed using multiple loss functions.
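The card does not say which loss functions were combined. As an illustration only, here is a sketch of a DistilBERT-style objective: a soft-target KL divergence, the usual MLM cross-entropy, and a cosine loss on hidden states. The function name, the temperature, the loss weights and the linear projection `proj` are all assumptions. The projection exists because a 384-dimensional student cannot be compared directly with a wider teacher such as BERT-base, and the teacher used for this model is not specified.

````python
import torch
import torch.nn.functional as F


def distillation_losses(student_logits, teacher_logits, labels,
                        student_hidden, teacher_hidden, proj,
                        temperature=2.0, weights=(1.0, 1.0, 1.0)):
    """Hypothetical combination of three distillation losses (not the author's exact recipe).

    student_logits, teacher_logits: (batch, seq, vocab)
    labels: (batch, seq), with -100 marking positions that have no MLM target
    student_hidden: (batch, seq, d_student); teacher_hidden: (batch, seq, d_teacher)
    proj: assumed module mapping d_student -> d_teacher
    """
    vocab = student_logits.size(-1)
    s = student_logits.reshape(-1, vocab)
    t = teacher_logits.reshape(-1, vocab)

    # 1) Soft targets: KL(teacher || student) between temperature-softened
    #    distributions, scaled by T^2 to keep gradient magnitudes comparable.
    kd = F.kl_div(F.log_softmax(s / temperature, dim=-1),
                  F.softmax(t / temperature, dim=-1),
                  reduction="batchmean") * temperature ** 2

    # 2) Hard targets: standard masked-language-modelling cross-entropy.
    mlm = F.cross_entropy(s, labels.reshape(-1), ignore_index=-100)

    # 3) Hidden-state alignment: 1 - cosine similarity after projecting
    #    the student states into the teacher's hidden space.
    cos = (1 - F.cosine_similarity(proj(student_hidden), teacher_hidden, dim=-1)).mean()

    total = weights[0] * kd + weights[1] * mlm + weights[2] * cos
    return {"kd": kd, "mlm": mlm, "cos": cos, "total": total}
````

The weighted sum keeps each term inspectable on its own, which is convenient when tuning the weights. Real setups may use other terms, such as MSE on hidden states or attention maps.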

The model's weights were randomly initialized, i.e. the model was trained from scratch rather than starting from a pretrained checkpoint.
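In `transformers`, building a model from a configuration object instead of calling `from_pretrained` gives randomly initialized weights. The sketch below assumes 12 layers, taken from the "L12" in the model name. It also assumes `intermediate_size=1536` (4× the hidden size, the usual BERT ratio), which is not stated in this card. To get the published configuration exactly, `AutoConfig.from_pretrained(model_name)` followed by `AutoModelForMaskedLM.from_config(config)` would be the safer route.

````python
from transformers import BertConfig, BertForMaskedLM

# Assumed dimensions: hidden size and heads come from the description above;
# 12 layers is read from the model name; intermediate_size is a guess (4x hidden).
config = BertConfig(
    vocab_size=30522,        # size of the bert-base-uncased vocabulary
    hidden_size=384,
    num_hidden_layers=12,
    num_attention_heads=6,
    intermediate_size=1536,
)

# Instantiating from a config (no from_pretrained) => random initialization.
student = BertForMaskedLM(config)

head_size = config.hidden_size // config.num_attention_heads
n_params = sum(p.numel() for p in student.parameters())
print(head_size, n_params)
````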

PS: the tokenizer is the same as that of `bert-base-uncased`.

**PS2: I am currently fixing a bug in this model. Do not expect meaningful results from it until my next update.**


To load the model and tokenizer:

````python
from transformers import AutoModelForMaskedLM, BertTokenizer

model_name = "eli4s/Bert-L12-h384-A6"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
````

To use it on a sentence:

````python
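import torch

sentence = "Let's have a [MASK]."

model.eval()
# Tokenize directly into PyTorch tensors
encoded_inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']

# Inference only: no gradients needed
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask)

# Locate the [MASK] token (id 103 in the bert-base-uncased vocabulary)
mask_index = input_ids[0].tolist().index(tokenizer.mask_token_id)
# Highest-scoring vocabulary entry at the masked position
masked_token = output['logits'][0, mask_index].argmax(dim=-1).item()
predicted_token = tokenizer.decode(masked_token)

print(predicted_token)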
````

We can also retrieve the n most likely predictions:

````python
top_n = 5

# Scores over the whole vocabulary at the masked position
logits = output['logits'][0, mask_index]
# torch.topk returns the n highest scores and their vocabulary ids
top_tokens = torch.topk(logits, k=top_n).indices.tolist()

print(tokenizer.convert_ids_to_tokens(top_tokens))
````
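The snippet above ranks raw logits for the first `[MASK]` only. If you want probabilities, or the sentence contains several masks, a small helper along these lines could work. `top_n_per_mask` is a hypothetical name, not part of `transformers`, and it only inspects the first sequence of the batch.

````python
import torch


def top_n_per_mask(logits, input_ids, mask_token_id, top_n=5):
    """Hypothetical helper: for every mask position in the first sequence,
    return (position, top_n vocabulary ids, their softmax probabilities)."""
    # Positions of every mask token (handles sentences with several masks)
    mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
    results = []
    for pos in mask_positions.tolist():
        # Softmax turns raw logits into a probability distribution over the vocabulary
        probs = torch.softmax(logits[0, pos], dim=-1)
        top = torch.topk(probs, k=top_n)  # sorted in descending order
        results.append((pos, top.indices.tolist(), top.values.tolist()))
    return results
````

With the objects from the previous snippets, `top_n_per_mask(output['logits'], input_ids, tokenizer.mask_token_id)` returns one `(position, ids, probabilities)` tuple per mask. `tokenizer.convert_ids_to_tokens(ids)` then maps the ids back to tokens.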