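"""Precompute OWLv2 text embeddings for every label name in id_to_str.json
and save them, keyed by label id, to embeds.pt."""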
import json

import torch
from torch import nn
import tqdm
from transformers import AutoTokenizer, Owlv2Processor, Owlv2TextModel
|

# One embedding per label id; nn.ParameterDict provides state_dict() for saving.
embed_dict = nn.ParameterDict()
bsz = 8

with open("id_to_str.json") as f:
    data = json.load(f)

keys = list(data.keys())
# Ceil division so a final, partial batch is not silently dropped.
num_batches = (len(keys) + bsz - 1) // bsz
bar = tqdm.tqdm(range(num_batches))

proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
tokenizer = AutoTokenizer.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()  # inference only; disable dropout

for i in bar:
    batch = [data[key].replace("_", " ") for key in keys[i * bsz : (i + 1) * bsz]]

    # The OWLv2 processor pads/truncates text queries to 16 tokens, so flag
    # any label that would be cut off instead of silently embedding it.
    token_ids = tokenizer(batch)["input_ids"]
    for ids in token_ids:
        if len(ids) > 16:
            tqdm.tqdm.write(f"Truncated to 16 tokens: {tokenizer.decode(ids)}")

    inputs = proc(text=batch, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    for k, key in enumerate(keys[i * bsz : (i + 1) * bsz]):
        # pooler_output holds one text embedding per query in the batch.
        embed_dict[key] = nn.Parameter(output.pooler_output[k], requires_grad=False)

torch.save(embed_dict.state_dict(), "embeds.pt")
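
# The saved state_dict is a plain mapping of label id -> 1-D embedding tensor,
# so it can be reloaded later without rebuilding the ParameterDict, e.g.:
#   embeds = torch.load("embeds.pt")  # OrderedDict[str, Tensor]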