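# Encode the RSICD test-split images with a Flax CLIP model (baseline or fine-tuned)
# and write one feature vector per image to a TSV file under VECTORS_DIR.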
import argparse
import json
import matplotlib.pyplot as plt
import numpy as np
import os

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel


def encode_image(image_file, model, processor):
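    """Read an image from IMAGES_DIR and return its CLIP image embedding as a flat numpy array."""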
    image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_file)))
    inputs = processor(images=image, return_tensors="jax")
    image_vec = model.get_image_features(**inputs)
    return np.array(image_vec).reshape(-1)


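# Locations of the RSICD images, captions and output vectors, plus the baseline CLIP checkpoint.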
DATA_DIR = "/home/shared/data"
IMAGES_DIR = os.path.join(DATA_DIR, "rsicd_images")
CAPTIONS_FILE = os.path.join(DATA_DIR, "dataset_rsicd.json")
VECTORS_DIR = os.path.join(DATA_DIR, "vectors")
BASELINE_MODEL = "openai/clip-vit-base-patch32"

parser = argparse.ArgumentParser()
parser.add_argument("model_dir", help="Path to the model to use for encoding, or 'baseline' for the pretrained CLIP checkpoint")
args = parser.parse_args()

print("Loading image list...", end="")
image2captions = {}
with open(CAPTIONS_FILE, "r") as fcap:
    data = json.load(fcap)
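# Keep only test-split images, mapping each filename to its list of raw captions.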
for image in data["images"]:
    if image["split"] == "test":
        filename = image["filename"]
        sentences = []
        for sentence in image["sentences"]:
            sentences.append(sentence["raw"])
        image2captions[filename] = sentences
    
print("{:d} images".format(len(image2captions)))


print("Loading model...")
if args.model_dir == "baseline":
    model = FlaxCLIPModel.from_pretrained(BASELINE_MODEL)
else:
    model = FlaxCLIPModel.from_pretrained(args.model_dir)
processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)


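# Output format: one line per image, "<filename>\t<comma-separated feature values>".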
model_basename = "-".join(args.model_dir.split("/")[-2:])
vector_file = os.path.join(VECTORS_DIR, "test-{:s}.tsv".format(model_basename))
os.makedirs(VECTORS_DIR, exist_ok=True)
print("Writing vectors to {:s}".format(vector_file))
num_written = 0
with open(vector_file, "w") as fvec:
    for image_file in image2captions.keys():
        if num_written % 100 == 0:
            print("{:d} images processed".format(num_written))
        image_vec = encode_image(image_file, model, processor)
        image_vec_s = ",".join(["{:.7e}".format(x) for x in image_vec])
        fvec.write("{:s}\t{:s}\n".format(image_file, image_vec_s))
        num_written += 1

print("{:d} images processed, COMPLETE".format(num_written))