File size: 1,872 Bytes
b4968e4
 
 
 
 
 
 
503acf7
 
 
 
b4968e4
 
 
 
 
8e37dd1
503acf7
532be59
b4968e4
8e37dd1
 
 
 
 
 
503acf7
8e37dd1
 
 
 
 
503acf7
8e37dd1
 
 
503acf7
b4968e4
8e37dd1
b4968e4
 
503acf7
b4968e4
 
 
 
 
503acf7
 
b4968e4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import csv
import os

from PIL import Image

from utils import load_model
import jax.numpy as jnp
from jax import jit

from tqdm import tqdm


def _load_rgb_batch(root, file_names):
    """Open every named image under *root* and return RGB copies of them."""
    batch = []
    for name in file_names:
        # The context manager releases the file handle right away; convert()
        # hands back a separate image object that outlives the `with` block.
        with Image.open(os.path.join(root, name)) as img:
            batch.append(img.convert('RGB'))
    return batch


def main(args):
    """Extract image embeddings for every ``.jpg`` in a directory.

    For each of the ``koclip-base`` and ``koclip-large`` models, the images
    are processed in batches and one row per image is appended to
    ``<out_path>/<model_name>.tsv``: the file name, a tab, and the embedding
    values joined by commas. Rows are appended, so running the script twice
    against the same output directory duplicates rows.

    Args:
        args: Namespace with ``image_path`` (directory of images),
            ``out_path`` (directory for the TSV files; created if missing)
            and ``batch_size`` (positive int, or a string that parses as one).

    Raises:
        ValueError: If ``image_path`` contains an entry that does not end in
            ``.jpg``, or if ``batch_size`` is smaller than 1.
    """
    root = args.image_path
    files = list(os.listdir(root))

    # Explicit validation instead of `assert`, which disappears under -O.
    unexpected = [name for name in files if not name.endswith(".jpg")]
    if unexpected:
        raise ValueError(f"non-.jpg entries found in {root!r}: {unexpected}")

    # argparse delivers a str unless the option declares type=int.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ["koclip-base", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        with tqdm(total=len(files)) as pbar:
            for start in range(0, len(files), batch_size):
                image_ids = files[start:start + batch_size]
                images = _load_rgb_batch(root, image_ids)

                try:
                    inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
                except Exception as err:
                    # Best effort: report the offending batch and give up on
                    # this model, but let KeyboardInterrupt/SystemExit through.
                    print(image_ids)
                    print(f"processor failed: {err!r}")
                    break

                # Moves axis 1 to the end -- presumably channels-first to
                # channels-last for the model; TODO confirm against load_model.
                inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
                features = model(**inputs).image_embeds

                # Re-opened per batch so finished batches are on disk even if
                # a later batch fails. newline="" is what the csv module
                # expects from files it writes to.
                with open(out_file, "a", newline="") as tsv:
                    writer = csv.writer(tsv, delimiter="\t")
                    for image_id, feature in zip(image_ids, features):
                        writer.writerow([image_id, ",".join(map(str, feature))])

                # Advance by the real batch size so the bar never overshoots
                # on the final, possibly shorter, batch.
                pbar.update(len(image_ids))


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    # type=int so main() receives an integer even when the value is supplied
    # on the command line (argparse yields str by default).
    parser.add_argument("--batch_size", type=int, default=16,
                        help="number of images per batch")
    parser.add_argument("--image_path", default="images",
                        help="directory containing the .jpg images")
    parser.add_argument("--out_path", default="features",
                        help="directory the <model_name>.tsv files are written to")
    args = parser.parse_args()
    main(args)