"""Visualize some sense vectors""" import torch import argparse import transformers def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None): """ Prints out the top-scoring words (and lowest-scoring words) for each sense. """ if contents is None: print(word) token_id = tokenizer(word)['input_ids'][0] contents = vecs[token_id] # torch.Size([16, 768]) for i in range(contents.shape[0]): print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i)) logits = contents[i,:] @ lm_head.t() # (vocab,) [768] @ [768, 50257] -> [50257] sorted_logits, sorted_indices = torch.sort(logits, descending=True) print('~~~Positive~~~') for j in range(count): print(tokenizer.decode(sorted_indices[j]), '\t','{:.2f}'.format(sorted_logits[j].item())) print('~~~Negative~~~') for j in range(count): print(tokenizer.decode(sorted_indices[-j-1]), '\t','{:.2f}'.format(sorted_logits[-j-1].item())) return contents print() print() print() argp = argparse.ArgumentParser() argp.add_argument('vecs_path') argp.add_argument('lm_head_path') args = argp.parse_args() # Load tokenizer and parameters tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') vecs = torch.load(args.vecs_path) lm_head = torch.load(args.lm_head_path) visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)