| | import re |
| | import os |
| | import fire |
| | import torch |
| | from functools import partial |
| | from transformers import AutoTokenizer |
| | from transformers import AutoModelForPreTraining |
| | from pya0.preprocess import preprocess_for_transformer |
| |
|
| |
|
def highlight_masked(txt):
    """Return *txt* with every literal '[MASK]' token wrapped in ANSI green.

    Purely cosmetic: used to make mask positions easy to spot in terminal
    output. Non-mask text is left untouched.
    """
    green = '\033[92m'
    reset = '\033[0m'
    return txt.replace('[MASK]', green + '[MASK]' + reset)
| |
|
| |
|
def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
    """Forward hook on the MLM prediction head: print top-k unmask candidates.

    Intended for use via ``functools.partial(classifier_hook, tokenizer,
    tokens, topk)`` so the remaining ``(module, inputs, outputs)`` match the
    torch forward-hook signature.

    Args:
        tokenizer: tokenizer that produced ``tokens``; supplies the mask
            token id and id->token conversion.
        tokens: encoded batch dict; only ``tokens['input_ids'][0]`` is read
            (single-sentence batch assumed).
        topk: number of candidate tokens to print per mask position.
        module, inputs, outputs: standard forward-hook arguments;
            ``outputs`` is the head's ``(unmask_scores, seq_rel_scores)``
            pair (the sequence-relationship scores are unused here).
    """
    unmask_scores, seq_rel_scores = outputs
    # Ask the tokenizer for its mask id instead of hard-coding 103 (which is
    # only correct for BERT-family vocabularies).
    mask_id = tokenizer.mask_token_id
    token_ids = tokens['input_ids'][0]
    masked_idx = (token_ids == mask_id)  # boolean mask over the sequence
    scores = unmask_scores[0][masked_idx]  # (n_masks, vocab_size)
    cands = torch.argsort(scores, dim=1, descending=True)
    for i, mask_cands in enumerate(cands):
        top_cands = mask_cands[:topk].detach().cpu()
        print(f'MASK[{i}] top candidates: ' +
              str(tokenizer.convert_ids_to_tokens(top_cands)))
| |
|
| |
|
def test(tokenizer_name_or_path, model_name_or_path, test_file='test.txt'):
    """Run masked-token prediction over a TSV file and print top candidates.

    Each line of *test_file* is ``<positions>\\t<text>`` where ``<positions>``
    is a comma-separated list of 1-based token positions to mask (a position
    of 0 is ignored). For every line the masked sentence is printed, then a
    forward hook on the model's prediction head prints the top-3 candidate
    tokens for each mask.

    Args:
        tokenizer_name_or_path: HuggingFace tokenizer name or local path.
        model_name_or_path: HuggingFace pretraining model name or local path.
        test_file: path to the TSV test data (default ``test.txt``).
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    model = AutoModelForPreTraining.from_pretrained(model_name_or_path,
        tie_word_embeddings=True
    )
    with open(test_file, 'r', encoding='utf-8') as fh:
        for line in fh:
            line = line.rstrip()
            if not line:
                continue  # tolerate blank/trailing lines in the test file
            fields = line.split('\t')
            maskpos = list(map(int, fields[0].split(',')))

            sentence = preprocess_for_transformer(fields[1])
            tokens = sentence.split()
            for pos in filter(lambda x: x != 0, maskpos):
                tokens[pos - 1] = '[MASK]'
            sentence = ' '.join(tokens)
            # preprocessing may lowercase; restore the canonical mask token
            sentence = sentence.replace('[mask]', '[MASK]')
            tokens = tokenizer(sentence,
                padding=True, truncation=True, return_tensors="pt")

            print('*', highlight_masked(sentence))

            with torch.no_grad():
                # presumably a BERT-style model exposing a `.cls` prediction
                # head — TODO confirm for other architectures
                classifier = model.cls
                partial_hook = partial(classifier_hook, tokenizer, tokens, 3)
                hook = classifier.register_forward_hook(partial_hook)
                model(**tokens)
                hook.remove()
| |
|
| |
|
if __name__ == '__main__':
    # Force a non-interactive pager so fire's help/trace output goes
    # straight to the terminal instead of invoking `less`.
    os.environ.update(PAGER='cat')
    fire.Fire(test)
| |
|