# Backpack-Demo/senses/use_senses.py
"""Visualize some sense vectors"""
import torch
import argparse
import transformers


def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
"""
Prints out the top-scoring words (and lowest-scoring words) for each sense.
"""
if contents is None:
print(word)
token_id = tokenizer(word)['input_ids'][0]
contents = vecs[token_id] # torch.Size([16, 768])
for i in range(contents.shape[0]):
print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
logits = contents[i,:] @ lm_head.t() # (vocab,) [768] @ [768, 50257] -> [50257]
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
print('~~~Positive~~~')
for j in range(count):
print(tokenizer.decode(sorted_indices[j]), '\t','{:.2f}'.format(sorted_logits[j].item()))
print('~~~Negative~~~')
for j in range(count):
print(tokenizer.decode(sorted_indices[-j-1]), '\t','{:.2f}'.format(sorted_logits[-j-1].item()))
return contents
print()
print()
print()
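
# Note: `contents` lets a caller pass a precomputed (or edited) sense matrix
# directly, skipping the lookup in `vecs`. A minimal sketch (the [16, 768]
# shape is inferred from the comment in the function body; the edit below is
# purely hypothetical):
#
#   senses = vecs[tokenizer('science')['input_ids'][0]].clone()  # [16, 768]
#   senses[3] *= 2.0  # hypothetical rescaling of sense 3
#   visualize_word('science', tokenizer, vecs, lm_head, count=5, contents=senses)
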
argp = argparse.ArgumentParser()
argp.add_argument('vecs_path')
argp.add_argument('lm_head_path')
args = argp.parse_args()
# Load tokenizer and parameters
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
vecs = torch.load(args.vecs_path)
lm_head = torch.load(args.lm_head_path)
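# Expected shapes (inferred from the comments in visualize_word; not checked here):
#   vecs:    (vocab_size, num_senses, d_model), e.g. (50257, 16, 768)
#   lm_head: (vocab_size, d_model),             e.g. (50257, 768)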
visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)
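
# Example invocation (a sketch; the .pt filenames are placeholders for wherever
# the sense vectors and LM head were saved with torch.save):
#
#   python use_senses.py sense_vectors.pt lm_head.pt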