import re
import os
import fire
import torch
from functools import partial
from transformers import AutoTokenizer
from transformers import AutoModelForPreTraining
from pya0.preprocess import preprocess_for_transformer
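
# This script reads tab-separated test sentences, replaces the requested
# positions with [MASK], runs them through a pre-trained (math-aware) BERT
# model, and prints the model's top unmasking candidates for each mask.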


def highlight_masked(txt):
    """Wrap each [MASK] token in ANSI green so it stands out in terminal output."""
    return re.sub(r"(\[MASK\])", '\033[92m' + r"\1" + '\033[0m', txt)


def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
    """Forward hook on the model's pre-training head: for every [MASK]
    position in the input, print the top-k candidate tokens."""
    unmask_scores, seq_rel_scores = outputs
    token_ids = tokens['input_ids'][0]
    # locate the masked positions in the (single-sentence) batch
    masked_idx = (token_ids == tokenizer.mask_token_id)
    scores = unmask_scores[0][masked_idx]
    # sort vocabulary scores for each masked position, best first
    cands = torch.argsort(scores, dim=1, descending=True)
    for i, mask_cands in enumerate(cands):
        top_cands = mask_cands[:topk].detach().cpu()
        print(f'MASK[{i}] top candidates: ' +
              str(tokenizer.convert_ids_to_tokens(top_cands)))


def test(tokenizer_name_or_path, model_name_or_path, test_file='test.txt'):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    model = AutoModelForPreTraining.from_pretrained(
        model_name_or_path, tie_word_embeddings=True
    )
    with open(test_file, 'r') as fh:
        for line in fh:
            # each line: comma-separated 1-based mask positions, a tab,
            # then the sentence itself
            line = line.rstrip()
            fields = line.split('\t')
            maskpos = list(map(int, fields[0].split(',')))

            sentence = preprocess_for_transformer(fields[1])
            words = sentence.split()
            # a position of 0 means "mask nothing on this line"
            for pos in filter(lambda x: x != 0, maskpos):
                words[pos - 1] = '[MASK]'
            sentence = ' '.join(words)
            # preprocessing may lower-case text, so restore the mask token
            sentence = sentence.replace('[mask]', '[MASK]')
            tokens = tokenizer(sentence,
                padding=True, truncation=True, return_tensors="pt")

            print('*', highlight_masked(sentence))

            with torch.no_grad():
                # hook the pre-training head (model.cls for BERT-style
                # models) so top candidates are printed during the forward pass
                classifier = model.cls
                partial_hook = partial(classifier_hook, tokenizer, tokens, 3)
                hook = classifier.register_forward_hook(partial_hook)
                model(**tokens)
                hook.remove()


if __name__ == '__main__':
    os.environ["PAGER"] = 'cat'  # stop python-fire from paging its output
    fire.Fire(test)
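
# Example usage (script name, paths, and data below are illustrative):
#   $ python unmask_test.py bert-base-uncased ./math-aware-bert
# Each test.txt line is "<positions>\t<sentence>", with 1-based,
# comma-separated mask positions (0 masks nothing), e.g.:
#   2,5\tprove that $x + y = y + x$ holds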