clip-italian-demo / precompute_features.py
g8a9's picture
Add code to precompute embeddings.
0f27d7b
raw
history blame
1.52 kB
from argparse import ArgumentParser
from pathlib import Path

import torch
from jax import numpy as jnp
from torchvision import datasets, transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoTokenizer

import utils
from modeling_hybrid_clip import FlaxHybridCLIP
# Per-channel mean / std applied by Normalize() after ToTensor().
# NOTE(review): these match the constants used by OpenAI's CLIP image
# preprocessing; confirm they are what the clip-italian vision tower expects.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def parse_args(argv=None):
    """Parse the command line for the feature-precomputation script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``in_dir``, ``out_file``, ``out_dir``,
        ``batch_size`` and ``num_workers``. The optional arguments default to
        the values this script previously hard-coded.
    """
    parser = ArgumentParser(
        description="Precompute CLIP-Italian image features for a dataset directory."
    )
    parser.add_argument(
        "in_dir", help="Input directory, passed to utils.CustomDataSet."
    )
    parser.add_argument(
        "out_file", help="Output file name, written inside --out_dir."
    )
    parser.add_argument(
        "--out_dir",
        default="static/features",
        help="Directory that out_file is written into (created if missing).",
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Images per DataLoader batch."
    )
    parser.add_argument(
        "--num_workers", type=int, default=16, help="DataLoader worker processes."
    )
    return parser.parse_args(argv)


def build_preprocess(image_size):
    """Return the evaluation-time image transform for the CLIP vision tower.

    The pipeline resizes the shorter edge to ``image_size`` (bicubic) and
    center-crops to ``image_size`` x ``image_size``. It then converts the
    image to a tensor and normalizes it with ``CLIP_MEAN`` / ``CLIP_STD``.
    """
    return transforms.Compose(
        [
            # A one-element list means "match the shorter edge", same as an
            # int; the list form is the one torchvision also accepts under
            # TorchScript.
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            Normalize(CLIP_MEAN, CLIP_STD),
        ]
    )


def main(argv=None):
    """Embed every item of the dataset in ``in_dir`` and save the features.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).
    """
    args = parse_args(argv)

    # Only the image tower is used here, so no text tokenizer is loaded.
    model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")

    preprocess = build_preprocess(model.config.vision_config.image_size)
    dataset = utils.CustomDataSet(args.in_dir, transform=preprocess)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        # Deterministic order, so feature rows can be matched back to dataset
        # indices (assuming precompute_image_features keeps batch order).
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,  # keep the final partial batch as well
    )
    image_features = utils.precompute_image_features(model, loader)

    out_path = Path(args.out_dir) / args.out_file
    # Create the target directory so the save does not fail on a fresh checkout.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # numpy.save (re-exported as jnp.save) appends ".npy" if the name lacks it.
    jnp.save(str(out_path), image_features)


if __name__ == "__main__":
    main()