w32zhong commited on
Commit
107cd34
0 Parent(s):

initial commit.

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. README.md +20 -0
  3. test.py +66 -0
  4. test.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ckpt/
2
+ *.tar.gz
3
+ *.swp
README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## About
2
+ Here we share a pretrained bert model that is aware of math tokens. The math tokens are treated specially and are tokenized using [pya0](https://github.com/approach0/pya0), which adds very limited new tokens for latex markup (total vocabulary is just 31061).
3
+
4
+ ### Usage
5
+ Download and try it out
6
+ ```sh
7
+ pip install pya0==0.3.2
8
+ wget https://vault.cs.uwaterloo.ca/s/gqstFZmWHCLGXe3/download -O ckpt.tar.gz
9
+ mkdir -p ckpt
10
+ tar xzf ckpt.tar.gz -C ckpt --strip-components=1
11
+ python test.py --test_file test.txt
12
+ ```
13
+
14
+ ### Test file format
15
+ Modify the test examples in `test.txt` to play with it.
16
+
17
+ The test file is tab separated, the first column is additional positions you want to mask for the right-side sentence (useful for masking tokens in math markups). An zero means no additional mask positions.
18
+
19
+ ### Example output
20
+ ![](https://i.imgur.com/xpl87KO.png)
test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import fire
4
+ import torch
5
+ from functools import partial
6
+ from transformers import BertTokenizer
7
+ from transformers import BertForPreTraining
8
+ from pya0.preprocess import preprocess_for_transformer
9
+
10
+
11
+ def highlight_masked(txt):
12
+ return re.sub(r"(\[MASK\])", '\033[92m' + r"\1" + '\033[0m', txt)
13
+
14
+
15
+ def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
16
+ unmask_scores, seq_rel_scores = outputs
17
+ MSK_CODE = 103
18
+ token_ids = tokens['input_ids'][0]
19
+ masked_idx = (token_ids == torch.tensor([MSK_CODE]))
20
+ scores = unmask_scores[0][masked_idx]
21
+ cands = torch.argsort(scores, dim=1, descending=True)
22
+ for i, mask_cands in enumerate(cands):
23
+ top_cands = mask_cands[:topk].detach().cpu()
24
+ print(f'MASK[{i}] top candidates: ' +
25
+ str(tokenizer.convert_ids_to_tokens(top_cands)))
26
+
27
+
28
+ def test(
29
+ test_file='test.txt',
30
+ ckpt_bert='ckpt/bert-pretrained-for-math-7ep/6_3_1382',
31
+ ckpt_tokenizer='ckpt/bert-tokenizer-for-math'
32
+ ):
33
+
34
+ tokenizer = BertTokenizer.from_pretrained(ckpt_tokenizer)
35
+ model = BertForPreTraining.from_pretrained(ckpt_bert,
36
+ tie_word_embeddings=True
37
+ )
38
+ with open(test_file, 'r') as fh:
39
+ for line in fh:
40
+ # parse test file line
41
+ line = line.rstrip()
42
+ fields = line.split('\t')
43
+ maskpos = list(map(int, fields[0].split(',')))
44
+ # preprocess and mask words
45
+ sentence = preprocess_for_transformer(fields[1])
46
+ tokens = sentence.split()
47
+ for pos in filter(lambda x: x!=0, maskpos):
48
+ tokens[pos-1] = '[MASK]'
49
+ sentence = ' '.join(tokens)
50
+ tokens = tokenizer(sentence,
51
+ padding=True, truncation=True, return_tensors="pt")
52
+ #print(tokenizer.decode(tokens['input_ids'][0]))
53
+ print('*', highlight_masked(sentence))
54
+ # print unmasked
55
+ with torch.no_grad():
56
+ display = ['\n', '']
57
+ classifier = model.cls
58
+ partial_hook = partial(classifier_hook, tokenizer, tokens, 3)
59
+ hook = classifier.register_forward_hook(partial_hook)
60
+ model(**tokens)
61
+ hook.remove()
62
+
63
+
64
+ if __name__ == '__main__':
65
+ os.environ["PAGER"] = 'cat'
66
+ fire.Fire(test)
test.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ 0 She needs to [MASK] that [MASK] gets only ten minutes.
2
+ 8 Determine the [MASK] of [imath]f(x) = x + \sqrt{ 4 - x^2}[/imath] without [MASK]
3
+ 4,12 Solve [imath]y''-4y'+4y=xe^x[/imath]
4
+ 4 [imath]f(x, y)[/imath]
5
+ 2 [imath]x + x = 2x[/imath]
6
+ 10,11 With Euler's [MASK], it [MASK] to [imath]\int_0^\infty \frac{1+x^2}{1+x}dx[/imath]
7
+ 6,12 Proof by [MASK] that [imath]n!>3n[/imath] [MASK] [imath]n>6[/imath]