"""Visualize some sense vectors"""

import torch
import argparse

import transformers

def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
  """
  Prints out the top-scoring words (and lowest-scoring words) for each sense.

  Arguments:
    word: string to look up sense vectors for; ignored when `contents` is given.
    tokenizer: tokenizer providing __call__ (for `word` lookup) and decode().
    vecs: tensor of per-token sense vectors, indexed by token id.
    lm_head: (vocab, hidden) output-embedding matrix; scores are contents @ lm_head.t().
    count: how many top and bottom words to print per sense.
    contents: optional (num_senses, hidden) tensor; bypasses the `word` lookup.

  Returns:
    The (num_senses, hidden) contents tensor that was visualized.
  """
  if contents is None:
    print(word)
    # NOTE(review): only the first sub-token of `word` is used — a multi-token
    # word silently visualizes just its first piece. Confirm this is intended.
    token_id = tokenizer(word)['input_ids'][0]
    contents = vecs[token_id]  # e.g. torch.Size([16, 768])

  for i in range(contents.shape[0]):
    print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
    # Score this sense vector against every vocabulary item:
    # [hidden] @ [hidden, vocab] -> [vocab]
    logits = contents[i,:] @ lm_head.t()
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    print('~~~Positive~~~')
    for j in range(count):
      print(tokenizer.decode(sorted_indices[j]), '\t','{:.2f}'.format(sorted_logits[j].item()))
    print('~~~Negative~~~')
    for j in range(count):
      print(tokenizer.decode(sorted_indices[-j-1]), '\t','{:.2f}'.format(sorted_logits[-j-1].item()))
  # Fixed: the original had three print() calls after this return — unreachable
  # dead code, now removed.
  return contents

# Script entry point: guard so importing this module does not trigger argument
# parsing, file loads, or the interactive prompt.
if __name__ == '__main__':
  argp = argparse.ArgumentParser(description=__doc__)
  argp.add_argument('vecs_path', help='path to a saved tensor of per-token sense vectors')
  argp.add_argument('lm_head_path', help='path to a saved LM output-embedding matrix')
  args = argp.parse_args()

  # Load tokenizer and parameters.
  tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
  # NOTE(review): torch.load unpickles arbitrary objects — only load checkpoint
  # files from trusted sources.
  vecs = torch.load(args.vecs_path)
  lm_head = torch.load(args.lm_head_path)

  visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)