"""Precompute image features with the CLIP-Italian model.

Encodes every image in a directory with the FlaxHybridCLIP vision tower and
saves the resulting feature matrix as a .npy file under static/features/.
"""
import os
from argparse import ArgumentParser

import torch
from jax import numpy as jnp
from torchvision import transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer

import utils
from modeling_hybrid_clip import FlaxHybridCLIP


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("in_dir", help="directory of images to encode")
    parser.add_argument("out_file", help="output filename (saved under static/features/)")
    args = parser.parse_args()

    # Load the pretrained CLIP-Italian hybrid model (an Italian BERT text
    # encoder paired with a CLIP vision encoder) from the Hugging Face Hub.
    model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")

    # Tokenizer matching the Italian BERT text encoder. It is loaded here for
    # completeness but not used below, since only image features are computed.
    tokenizer = AutoTokenizer.from_pretrained(
        "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
    )

    # Input resolution expected by the vision encoder (e.g. 224 for a ViT).
    image_size = model.config.vision_config.image_size

    # Standard CLIP validation preprocessing: resize the shorter edge to the
    # model's input size, center-crop to a square, and normalize with the
    # channel statistics used by OpenAI's CLIP.
    val_preprocess = transforms.Compose(
        [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    # CustomDataSet (from the local utils module) yields a preprocessed image
    # tensor for every file in the input directory.
    dataset = utils.CustomDataSet(args.in_dir, transform=val_preprocess)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,  # keep file order so features stay aligned with filenames
        num_workers=16,
        drop_last=False,  # encode every image, including the final partial batch
    )

    # Run every batch through the vision tower and save the stacked features
    # as a NumPy array; make sure the output directory exists first.
    os.makedirs("static/features", exist_ok=True)
    image_features = utils.precompute_image_features(model, loader)
    jnp.save(f"static/features/{args.out_file}", image_features)
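
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the two helpers this script expects from
# the local `utils` module, reconstructed from how they are called above. This
# is a hypothetical illustration, not the actual clip-italian implementation:
#
#     class CustomDataSet(torch.utils.data.Dataset):
#         """Yield transformed images from a flat directory, sorted by name."""
#
#         def __init__(self, main_dir, transform):
#             self.main_dir = main_dir
#             self.transform = transform
#             self.all_imgs = sorted(os.listdir(main_dir))
#
#         def __len__(self):
#             return len(self.all_imgs)
#
#         def __getitem__(self, idx):
#             path = os.path.join(self.main_dir, self.all_imgs[idx])
#             image = Image.open(path).convert("RGB")
#             return self.transform(image)
#
#     def precompute_image_features(model, loader):
#         """Encode all batches with the vision tower; return one array."""
#         features = []
#         for batch in loader:
#             # NCHW (PyTorch) -> NHWC (Flax) before handing tensors to JAX.
#             pixels = batch.permute(0, 2, 3, 1).numpy()
#             emb = model.get_image_features(pixels)
#             # L2-normalize so downstream retrieval can use dot products.
#             emb /= jnp.linalg.norm(emb, axis=-1, keepdims=True)
#             features.extend(emb)
#         return jnp.array(features)
# ---------------------------------------------------------------------------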