# Hugging Face Spaces script — precompute CLIP-Italian image features.
# (Original lines here were web-page scrape residue: "Spaces: / Running / Running".)
"""Precompute CLIP-Italian image features for a directory of images.

Usage:
    python precompute_features.py IN_DIR OUT_FILE

Loads the pretrained ``clip-italian/clip-italian`` FlaxHybridCLIP model,
runs every image in IN_DIR through the standard CLIP evaluation
preprocessing, and saves the resulting feature array to
``static/features/OUT_FILE`` with ``jnp.save``.
"""
from argparse import ArgumentParser

import torch
from jax import numpy as jnp
from torchvision import datasets, transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoTokenizer

import utils
from modeling_hybrid_clip import FlaxHybridCLIP

# CLIP's canonical per-channel normalization constants (RGB mean / std).
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def main() -> None:
    """Parse CLI arguments, embed all images in ``in_dir``, save the features."""
    parser = ArgumentParser(description=__doc__)
    parser.add_argument("in_dir", help="Directory containing the images to embed.")
    parser.add_argument(
        "out_file", help="Output file name, written under static/features/."
    )
    args = parser.parse_args()

    # Weights are fetched from the Hugging Face hub (cached after first run).
    model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
    # NOTE(review): the tokenizer is loaded but never used for image-only
    # feature extraction — kept for parity with the original script; confirm
    # whether a fuller version of this file (or utils) relies on it.
    tokenizer = AutoTokenizer.from_pretrained(
        "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
    )

    # Resize the short side to the model's input resolution, then center-crop
    # to a square — the standard CLIP evaluation preprocessing pipeline.
    image_size = model.config.vision_config.image_size
    val_preprocess = transforms.Compose(
        [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            Normalize(_CLIP_MEAN, _CLIP_STD),
        ]
    )

    dataset = utils.CustomDataSet(args.in_dir, transform=val_preprocess)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,  # keep feature rows aligned with dataset file order
        num_workers=16,
        drop_last=False,  # embed every image, including a partial last batch
    )

    image_features = utils.precompute_image_features(model, loader)
    jnp.save(f"static/features/{args.out_file}", image_features)


if __name__ == "__main__":
    main()