# azbert-base / test.py
import re
import os
import fire
import torch
from functools import partial
from transformers import BertTokenizer
from transformers import BertForPreTraining
from pya0.preprocess import preprocess_for_transformer


def highlight_masked(txt):
    # wrap every [MASK] token in ANSI green so it stands out in the terminal
    return re.sub(r"(\[MASK\])", '\033[92m' + r"\1" + '\033[0m', txt)


def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
    # forward hook on the pre-training heads: `outputs` is the tuple of
    # (masked-LM prediction scores, next-sentence-prediction scores)
    unmask_scores, seq_rel_scores = outputs
    # use the tokenizer's own [MASK] id (103 in the stock BERT vocabulary)
    mask_id = tokenizer.mask_token_id
    token_ids = tokens['input_ids'][0]
    masked_idx = (token_ids == mask_id)
    scores = unmask_scores[0][masked_idx]
    cands = torch.argsort(scores, dim=1, descending=True)
    for i, mask_cands in enumerate(cands):
        top_cands = mask_cands[:topk].detach().cpu()
        print(f'MASK[{i}] top candidates: ' +
              str(tokenizer.convert_ids_to_tokens(top_cands)))
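

# For reference, the same readout can be done without a forward hook by taking
# the masked-LM logits straight from the model output. A minimal sketch,
# assuming a transformers version whose BertForPreTraining returns an output
# object with a `prediction_logits` field (the helper name `topk_unmask` is
# illustrative, not part of the original script):
def topk_unmask(tokenizer, model, sentence, topk=3):
    enc = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        out = model(**enc)
    mlm_logits = out.prediction_logits[0]  # shape: (seq_len, vocab_size)
    masked = (enc['input_ids'][0] == tokenizer.mask_token_id)
    for i, row in enumerate(mlm_logits[masked]):
        ids = torch.argsort(row, descending=True)[:topk]
        print(f'MASK[{i}] top candidates:',
              tokenizer.convert_ids_to_tokens(ids.tolist()))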


def test(
        test_file='test.txt',
        ckpt_bert='ckpt/bert-pretrained-for-math-7ep/6_3_1382',
        ckpt_tokenizer='ckpt/bert-tokenizer-for-math'):
    tokenizer = BertTokenizer.from_pretrained(ckpt_tokenizer)
    # tie_word_embeddings shares the input embeddings with the MLM decoder
    model = BertForPreTraining.from_pretrained(
        ckpt_bert, tie_word_embeddings=True
    )
    with open(test_file, 'r') as fh:
        for line in fh:
            # parse a test-file line: comma-separated mask positions,
            # a TAB, then the sentence
            line = line.rstrip()
            fields = line.split('\t')
            maskpos = list(map(int, fields[0].split(',')))
            # preprocess, then replace the 1-based positions with [MASK]
            # (a position of 0 means nothing is masked)
            sentence = preprocess_for_transformer(fields[1])
            tokens = sentence.split()
            for pos in filter(lambda x: x != 0, maskpos):
                tokens[pos - 1] = '[MASK]'
            sentence = ' '.join(tokens)
            tokens = tokenizer(sentence, padding=True, truncation=True,
                               return_tensors="pt")
            print('*', highlight_masked(sentence))
            # run the model under a forward hook on model.cls (the
            # pre-training heads), which prints the top-3 candidates
            # for every masked position
            with torch.no_grad():
                classifier = model.cls
                partial_hook = partial(classifier_hook, tokenizer, tokens, 3)
                hook = classifier.register_forward_hook(partial_hook)
                model(**tokens)
                hook.remove()
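

# An illustrative test.txt line, inferred from the parser above (the sentence
# itself is made up for illustration; <TAB> stands for a literal tab):
#
#     2,5<TAB>prove that the sum $a + b$ is even
#
# This masks the 2nd and 5th whitespace-separated tokens of the
# preprocessed sentence.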


if __name__ == '__main__':
    os.environ["PAGER"] = 'cat'  # stop python-fire from paging its output
    fire.Fire(test)
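
# Example invocation (python-fire exposes the keyword arguments of test() as
# command-line flags; the checkpoint paths below are this script's defaults
# and must exist locally):
#
#     python test.py --test_file test.txt \
#         --ckpt_bert ckpt/bert-pretrained-for-math-7ep/6_3_1382 \
#         --ckpt_tokenizer ckpt/bert-tokenizer-for-math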