clip-rsicd-demo / demo-image-encoder.py
sujitpal's picture
new: initial revision (copied from main repo)
357b0b8
raw
history blame
2.23 kB
import argparse
import jax
import jax.numpy as jnp
import json
import matplotlib.pyplot as plt
import numpy as np
import requests
import os
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
def encode_image(image_file, model, processor):
    """Encode a single image into a flat image-feature vector.

    Args:
        image_file: File name of the image, resolved relative to
            ``IMAGES_DIR``.
        model: Object exposing ``get_image_features(**inputs)`` (this
            script passes a ``FlaxCLIPModel``).
        processor: Callable preprocessor (this script passes a
            ``CLIPProcessor``) invoked with ``images=`` and
            ``return_tensors="jax"``.

    Returns:
        A 1-D NumPy array containing the features the model produced
        for the image.
    """
    image_path = os.path.join(IMAGES_DIR, image_file)
    # Decode with PIL directly rather than going through plt.imread and
    # Image.fromarray: matplotlib hands back float arrays for PNG input,
    # which cannot be converted back into an RGB PIL image. The context
    # manager releases the file handle as soon as the pixels are loaded,
    # and convert("RGB") guarantees a three-channel image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="jax")
    image_vec = model.get_image_features(**inputs)
    # Flatten whatever shape the model returned into a plain vector.
    return np.array(image_vec).reshape(-1)
# Root directory under which all inputs and outputs of this script live.
DATA_DIR = "/home/shared/data"
# Directory that encode_image() reads image files from.
IMAGES_DIR = os.path.join(DATA_DIR, "rsicd_images")
# JSON file with an "images" list; each entry is read for its "split",
# "filename" and "sentences" fields.
CAPTIONS_FILE = os.path.join(DATA_DIR, "dataset_rsicd.json")
# Directory the output TSV of image vectors is written to.
VECTORS_DIR = os.path.join(DATA_DIR, "vectors")
# Model identifier used when model_dir is "baseline", and always used to
# load the processor.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# Command line: a single positional argument naming the model to load.
# The literal value "baseline" selects BASELINE_MODEL instead of a path.
parser = argparse.ArgumentParser()
parser.add_argument("model_dir", help="Path to model to use for encoding")
args = parser.parse_args()
# Build a mapping from each test-split image file name to the list of its
# raw caption strings, as recorded in CAPTIONS_FILE.
print("Loading image list...", end="")
# JSON is UTF-8 by specification, so decode explicitly instead of relying
# on the platform's default encoding.
with open(CAPTIONS_FILE, "r", encoding="utf-8") as fcap:
    data = json.load(fcap)
# Only entries whose "split" is "test" are kept; if a file name occurs
# more than once, the last occurrence wins.
image2captions = {
    image["filename"]: [sentence["raw"] for sentence in image["sentences"]]
    for image in data["images"]
    if image["split"] == "test"
}
print("{:d} images".format(len(image2captions)))
# Load the model named on the command line; the special value "baseline"
# stands for the stock BASELINE_MODEL checkpoint. The processor is always
# taken from BASELINE_MODEL regardless of which model is loaded.
print("Loading model...")
model_source = BASELINE_MODEL if args.model_dir == "baseline" else args.model_dir
model = FlaxCLIPModel.from_pretrained(model_source)
processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
# Encode every selected image and write one line per image to a TSV file:
# the image file name, a tab, then the comma-separated vector components.

# The output file is named after the last two path components of the model
# directory. Trailing slashes are stripped first so that "a/b/" yields
# "a-b" rather than "b-".
model_basename = "-".join(args.model_dir.rstrip("/").split("/")[-2:])
vector_file = os.path.join(VECTORS_DIR, "test-{:s}.tsv".format(model_basename))
print("Vectors written to {:s}".format(vector_file))

# Make sure the output directory exists before opening the file in it.
os.makedirs(VECTORS_DIR, exist_ok=True)

num_written = 0
# The context manager closes (and flushes) the output file even if
# encoding one of the images raises part-way through the loop.
with open(vector_file, "w") as fvec:
    for image_file in image2captions:
        # Progress report before every 100th image, starting at 0.
        if num_written % 100 == 0:
            print("{:d} images processed".format(num_written))
        image_vec = encode_image(image_file, model, processor)
        image_vec_s = ",".join("{:.7e}".format(x) for x in image_vec)
        fvec.write("{:s}\t{:s}\n".format(image_file, image_vec_s))
        num_written += 1
print("{:d} images processed, COMPLETE".format(num_written))