junjuice0 commited on
Commit
607d6f6
1 Parent(s): b9a031f

Upload get_embeds.py

Browse files
Files changed (1) hide show
  1. get_embeds.py +32 -0
get_embeds.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Owlv2TextModel, Owlv2Processor, AutoTokenizer
2
+ import json
3
+ import torch
4
+ from torch import nn
5
+ import tqdm
6
+
7
+ embed_dict = nn.ParameterDict()
8
+ bsz = 8
9
+
10
+ with open("id_to_str.json") as f:
11
+ data = json.load(f)
12
+
13
+ keys = list(data.keys())
14
+ bar = tqdm.tqdm(range(len(keys)//bsz))
15
+
16
+ proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
17
+ tokenizer = AutoTokenizer.from_pretrained("google/owlv2-base-patch16-ensemble")
18
+ model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
19
+
20
+ for i in bar:
21
+ batch = [data[key].replace("_", " ") for key in keys[i*bsz:(i+1)*bsz]]
22
+ tokenized = tokenizer(batch)
23
+ for k in range(bsz):
24
+ if len(tokenized[k]) > 16:
25
+ tokenizer.decode(tokenized[k])
26
+
27
+ batch = proc(text=batch, return_tensors="pt")
28
+ output = model(**batch)
29
+ for k, key in enumerate(keys[i*bsz:(i+1)*bsz]):
30
+ embed_dict[key] = output.pooler_output[k, :]
31
+
32
+ torch.save(embed_dict.state_dict(), "embeds.pt")