File size: 2,008 Bytes
b4968e4
 
 
 
503acf7
 
2cf3514
503acf7
b4968e4
2cf3514
 
b4968e4
 
 
 
8e37dd1
2cf3514
532be59
b4968e4
8e37dd1
 
 
 
 
 
2cf3514
8e37dd1
 
 
 
 
2cf3514
 
 
8e37dd1
 
 
2cf3514
 
 
b4968e4
8e37dd1
b4968e4
 
2cf3514
 
 
b4968e4
 
 
 
 
503acf7
 
b4968e4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import csv
import os

import jax.numpy as jnp
from jax import jit
from PIL import Image
from tqdm import tqdm

from utils import load_model


def _load_images(root, names):
    """Load the named files from directory *root* as RGB PIL images.

    Each file is opened in a ``with`` block so its handle is released
    promptly; ``convert("RGB")`` returns a new, fully loaded image that stays
    usable after the source file has been closed.
    """
    images = []
    for name in names:
        with Image.open(os.path.join(root, name)) as img:
            images.append(img.convert("RGB"))
    return images


def main(args):
    """Extract KoCLIP image embeddings for every ``.jpg`` in ``args.image_path``.

    For each of the ``koclip-base`` and ``koclip-large`` models, images are
    processed in batches of ``args.batch_size`` and one row per image is
    appended to ``<args.out_path>/<model_name>.tsv``: the file name, a tab,
    then the embedding values joined with commas.

    Args:
        args: Namespace with ``image_path`` (input directory), ``out_path``
            (output directory, created if missing) and ``batch_size``
            (an int, or an int-like string as argparse may supply).

    Raises:
        ValueError: If ``batch_size`` is < 1 or ``image_path`` contains a
            file whose name does not end in ``.jpg``.
    """
    root = args.image_path
    # Without ``type=int`` argparse yields a str for a command-line value;
    # coerce here so range() below cannot fail with a TypeError.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    files = os.listdir(root)
    # Explicit check rather than ``assert``, which is stripped under -O.
    non_jpg = [name for name in files if not name.endswith(".jpg")]
    if non_jpg:
        raise ValueError(f"{root!r} must contain only .jpg files; found {non_jpg}")

    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ["koclip-base", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # Append mode (as before): re-running adds rows rather than
        # overwriting. newline="" lets the csv writer control line endings.
        with open(out_file, "a+", newline="", encoding="utf-8") as out, tqdm(
            total=len(files)
        ) as pbar:
            writer = csv.writer(out, delimiter="\t")
            for start in range(0, len(files), batch_size):
                image_ids = files[start : start + batch_size]
                images = _load_images(root, image_ids)
                # Advance by the real batch length; the last batch may be short.
                pbar.update(len(image_ids))
                try:
                    # NOTE(review): a dummy text prompt is presumably needed
                    # because the model forward pass expects text inputs;
                    # only ``image_embeds`` is used below.
                    inputs = processor(
                        text=[""], images=images, return_tensors="jax", padding=True
                    )
                except Exception as err:
                    # Report the failing batch and stop processing this model
                    # (remaining files are skipped), matching prior behaviour.
                    print(f"failed to preprocess {image_ids}: {err!r}")
                    break
                # Move axis 1 to the end. Presumably the processor emits
                # channels-first (N, C, H, W) while the Flax model expects
                # channels-last (N, H, W, C) — confirm against the model.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                features = model(**inputs).image_embeds
                writer.writerows(
                    (image_id, ",".join(map(str, feature)))
                    for image_id, feature in zip(image_ids, features)
                )


if __name__ == "__main__":
    # Command-line entry point: parse options and run feature extraction.
    parser = argparse.ArgumentParser(
        description="Extract KoCLIP image embeddings into per-model TSV files."
    )
    # type=int is required: without it a user-supplied value arrives as a str
    # and cannot be used as a range() step.
    parser.add_argument(
        "--batch_size", type=int, default=16, help="images per forward pass"
    )
    parser.add_argument(
        "--image_path", default="images", help="directory of .jpg images"
    )
    parser.add_argument(
        "--out_path", default="features", help="directory for output .tsv files"
    )
    args = parser.parse_args()
    main(args)