# get_embeds.py
# (uploaded by junjuice0, commit 607d6f6)
# Precomputes OWLv2 text embeddings for the class strings in id_to_str.json.
from transformers import Owlv2TextModel, Owlv2Processor, AutoTokenizer
import json
import torch
from torch import nn
import tqdm
# Compute an OWLv2 text embedding for every class string in id_to_str.json
# (mapping id -> human-readable name, "_"-separated) and save the collected
# embeddings as a state dict to "embeds.pt".
embed_dict = nn.ParameterDict()
bsz = 8  # number of class strings per forward pass

with open("id_to_str.json") as f:
    data = json.load(f)
keys = list(data.keys())

proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()  # inference only: disable dropout etc.

# Ceil-divide so the trailing partial batch is processed too — the original
# floor division (len(keys)//bsz) silently dropped up to bsz-1 final keys.
n_batches = (len(keys) + bsz - 1) // bsz
bar = tqdm.tqdm(range(n_batches))

with torch.no_grad():  # no gradients needed; avoids retaining autograd graphs
    for i in bar:
        batch_keys = keys[i * bsz:(i + 1) * bsz]
        # Class names use "_" as a word separator; the text model expects
        # plain space-separated text.
        batch = [data[key].replace("_", " ") for key in batch_keys]
        # NOTE(review): the processor tokenizes with its own padding/truncation
        # defaults; strings longer than the model's max text length are
        # presumably truncated — confirm if long class names matter.
        inputs = proc(text=batch, return_tensors="pt")
        output = model(**inputs)
        for k, key in enumerate(batch_keys):
            # detach+clone so the saved tensor is a standalone copy rather
            # than a view into the whole batch output.
            embed_dict[key] = output.pooler_output[k, :].detach().clone()

torch.save(embed_dict.state_dict(), "embeds.pt")