Lora committed on
Commit 97ff57c
1 Parent(s): 71ffc8f

remove old sense files

.gitignore ADDED
@@ -0,0 +1 @@
+ gradio_cached_examples/

senses/all_vecs_mtx.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1f0c9de5688dd793470c40ebc3b49c29be6ddbf9a38804bca64512940671e129
- size 2470232826

senses/lm_head.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f94054e64b4d1a07e18443769df4d3b9e346c00b02ffe4e9579e8313034dac24
- size 154411755

senses/use_senses.py DELETED
@@ -1,44 +0,0 @@
- """Visualize some sense vectors"""
-
- import torch
- import argparse
-
- import transformers
-
- def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
-     """
-     Prints out the top-scoring words (and lowest-scoring words) for each sense.
-
-     """
-     if contents is None:
-         print(word)
-         token_id = tokenizer(word)['input_ids'][0]
-         contents = vecs[token_id] # torch.Size([16, 768])
-
-     for i in range(contents.shape[0]):
-         print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
-         logits = contents[i,:] @ lm_head.t() # (vocab,) [768] @ [768, 50257] -> [50257]
-         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-         print('~~~Positive~~~')
-         for j in range(count):
-             print(tokenizer.decode(sorted_indices[j]), '\t','{:.2f}'.format(sorted_logits[j].item()))
-         print('~~~Negative~~~')
-         for j in range(count):
-             print(tokenizer.decode(sorted_indices[-j-1]), '\t','{:.2f}'.format(sorted_logits[-j-1].item()))
-     return contents
-     print()
-     print()
-     print()
-
- argp = argparse.ArgumentParser()
- argp.add_argument('vecs_path')
- argp.add_argument('lm_head_path')
- args = argp.parse_args()
-
- # Load tokenizer and parameters
- tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
- vecs = torch.load(args.vecs_path)
- lm_head = torch.load(args.lm_head_path)
-
- visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)
-
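For reference, the removed script scored each of a word's sense vectors against the GPT-2 LM head; below is a minimal standalone sketch of that same projection. The file paths are the two checkpoints deleted in this commit, the (vocab, 16, 768) and (50257, 768) shapes are taken from the comments in the deleted code, the word 'science' and k=5 are arbitrary, and it assumes local copies of the old .pt files are still available.

    import torch
    import transformers

    # Paths match the two LFS files removed in this commit (assumes local copies exist).
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    vecs = torch.load('senses/all_vecs_mtx.pt')   # (vocab, 16, 768) sense vectors
    lm_head = torch.load('senses/lm_head.pt')     # (50257, 768) output embedding

    token_id = tokenizer('science')['input_ids'][0]
    senses = vecs[token_id]                       # (16, 768), as in the deleted script

    for i, sense in enumerate(senses):
        scores = sense @ lm_head.t()              # (50257,) logits over the vocabulary
        top = torch.topk(scores, k=5).indices     # highest-scoring tokens for this sense
        print(i, [tokenizer.decode(int(t)) for t in top])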