Commit 7177c13 by imvladikon (1 parent: 6e782ec)

Create README.md

Files changed (1): README.md (+93, -0)

README.md ADDED:
---
language:
- he
tags:
- language model
pipeline_tag: feature-extraction
---

## AlephBertGimmel

A BERT model pretrained on Modern Hebrew, with a 128K-token vocabulary.

This is the [checkpoint](https://github.com/Dicta-Israel-Center-for-Text-Analysis/alephbertgimmel/tree/main/alephbertgimmel-small/ckpt_29400--Max128Seq) of alephbertgimmel-small-128 from the [alephbertgimmel](https://github.com/Dicta-Israel-Center-for-Text-Analysis/alephbertgimmel) repository.
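
A quick sanity check on the extra-large vocabulary before running the masked-token prediction example below. This is a minimal sketch, not part of the original card; the card says "128K" without an exact count, so the printed number is an expectation rather than a documented figure.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("imvladikon/alephbertgimmel-small-128")

# The card advertises a 128K-token vocabulary; len(tokenizer) reports the
# checkpoint's actual vocabulary size, including special tokens.
print(len(tokenizer))
```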

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained("imvladikon/alephbertgimmel-small-128")
tokenizer = AutoTokenizer.from_pretrained("imvladikon/alephbertgimmel-small-128")

# "{} is a metropolis that constitutes the center of the economy"
text = "{} היא מטרופולין המהווה את מרכז הכלכלה"

# Encode the sentence with [MASK] in the slot and locate the masked position
input_ids = tokenizer.encode(text.format("[MASK]"), return_tensors="pt")
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

# Score the vocabulary at the masked position and keep the top-5 candidates
token_logits = model(input_ids).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(text.format(tokenizer.decode([token])))

# ישראל היא מטרופולין המהווה את מרכז הכלכלה
# ירושלים היא מטרופולין המהווה את מרכז הכלכלה
# חיפה היא מטרופולין המהווה את מרכז הכלכלה
# אילת היא מטרופולין המהווה את מרכז הכלכלה
# אשדוד היא מטרופולין המהווה את מרכז הכלכלה
```
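
The same top-5 prediction can also be run through the `fill-mask` pipeline. This is a hedged equivalent of the snippet above, not an example from the original card:

```python
from transformers import pipeline

# The fill-mask pipeline wraps the tokenizer and masked-LM head used above
fill_mask = pipeline("fill-mask", model="imvladikon/alephbertgimmel-small-128")

# "[MASK] is a metropolis that constitutes the center of the economy"
for prediction in fill_mask("[MASK] היא מטרופולין המהווה את מרכז הכלכלה", top_k=5):
    print(prediction["token_str"], round(prediction["score"], 3))
```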

```python
def ppl_naive(text, model, tokenizer):
    # Naive "perplexity": the unmasked sentence serves as its own labels
    input_ids = tokenizer.encode(text, return_tensors="pt")
    loss = model(input_ids, labels=input_ids)[0]
    return torch.exp(loss).item()

# "{} is the capital city of the State of Israel, and the largest city in Israel by population size"
text = """{} היא עיר הבירה של מדינת ישראל, והעיר הגדולה ביותר בישראל בגודל האוכלוסייה"""

# Haifa, Jerusalem, Tel Aviv
for word in ["חיפה", "ירושלים", "תל אביב"]:
    print(ppl_naive(text.format(word), model, tokenizer))

# 9.825098991394043
# 10.594215393066406
# 9.536449432373047

# I'd expect "ירושלים" (Jerusalem) to get the smallest value here, but it doesn't.
# Note that the naive score shows the model the very tokens it is asked to predict;
# the pseudo-perplexity below masks one token at a time instead:

@torch.inference_mode()
def ppl_pseudo(text, model, tokenizer, ignore_idx=-100):
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # Build one copy of the sentence per inner token, each with a single [MASK]
    # ([CLS] and [SEP] are never masked)
    mask = torch.ones(input_ids.size(-1) - 1).diag(1)[:-2]
    repeat_input = input_ids.repeat(input_ids.size(-1) - 2, 1)
    masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
    # Compute the loss only at the masked positions
    labels = repeat_input.masked_fill(masked_input != tokenizer.mask_token_id, ignore_idx)
    loss = model(masked_input, labels=labels)[0]
    return torch.exp(loss).item()

for word in ["חיפה", "ירושלים", "תל אביב"]:
    print(ppl_pseudo(text.format(word), model, tokenizer))

# 4.346900939941406
# 3.292382001876831
# 2.732590913772583
```
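
The card's `pipeline_tag` is `feature-extraction`, so the model can also be used as a plain encoder. The sketch below is illustrative only: mean pooling over the last hidden state is an assumption, not something the original card prescribes, and loading with `AutoModel` drops the masked-LM head.

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("imvladikon/alephbertgimmel-small-128")
encoder = AutoModel.from_pretrained("imvladikon/alephbertgimmel-small-128")

# "Jerusalem is the capital city of the State of Israel"
sentence = "ירושלים היא עיר הבירה של מדינת ישראל"

with torch.inference_mode():
    encoded = tokenizer(sentence, return_tensors="pt")
    hidden = encoder(**encoded).last_hidden_state   # (1, seq_len, hidden_size)
    mask = encoded["attention_mask"].unsqueeze(-1)  # (1, seq_len, 1)
    # Mean-pool over non-padding tokens to get one sentence vector
    embedding = (hidden * mask).sum(dim=1) / mask.sum(dim=1)

print(embedding.shape)  # torch.Size([1, hidden_size])
```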

When using AlephBertGimmel, please cite:

```bibtex
@misc{guetta2022large,
      title={Large Pre-Trained Models with Extra-Large Vocabularies: A Contrastive Analysis of Hebrew BERT Models and a New One to Outperform Them All},
      author={Eylon Guetta and Avi Shmidman and Shaltiel Shmidman and Cheyn Shmuel Shmidman and Joshua Guedalia and Moshe Koppel and Dan Bareket and Amit Seker and Reut Tsarfaty},
      year={2022},
      eprint={2211.15199},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```