"""Extract image embeddings with the KoCLIP models and append them to TSV files.

For every model listed in ``MODEL_NAMES`` the script runs all ``.jpg`` files
found in ``--image_path`` through the model in batches and appends one row per
image (file name, comma-joined embedding values) to
``<out_path>/<model_name>.tsv``.
"""
import argparse
import csv
import os
from PIL import Image
from utils import load_model
import jax.numpy as jnp
from jax import jit
from tqdm import tqdm

# Model identifiers; each is loaded as "koclip/<name>" and names its own TSV file.
MODEL_NAMES = ("koclip", "koclip-large")


def _load_images(root, names):
    """Return RGB copies of the image files *names* located under *root*."""
    images = []
    for name in names:
        # convert() yields an independent in-memory image, so the underlying
        # file handle can be released as soon as the conversion is done.
        with Image.open(os.path.join(root, name)) as img:
            images.append(img.convert("RGB"))
    return images


def _append_rows(out_file, image_ids, features):
    """Append one ``image_id <TAB> v1,v2,...`` row per image to *out_file*."""
    # newline="" is what the csv module expects; without it rows get an
    # extra carriage return on platforms that translate "\n".
    with open(out_file, "a+", newline="") as handle:
        writer = csv.writer(handle, delimiter="\t")
        for image_id, feature in zip(image_ids, features):
            writer.writerow([image_id, ",".join(map(str, feature))])


def main(args):
    """Compute and store image embeddings for every model in ``MODEL_NAMES``.

    Args:
        args: Namespace-like object providing ``image_path`` (directory that
            contains only ``.jpg`` files), ``batch_size`` (int, images per
            forward pass) and ``out_path`` (directory receiving the TSV files).

    Raises:
        ValueError: If ``image_path`` contains an entry whose name does not
            end in ``.jpg``.
    """
    root = args.image_path
    files = os.listdir(root)
    for name in files:
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not name.endswith(".jpg"):
            raise ValueError(f"unexpected non-.jpg entry in {root!r}: {name!r}")

    # An empty out_path means "current directory"; nothing to create then.
    if args.out_path:
        os.makedirs(args.out_path, exist_ok=True)

    for model_name in MODEL_NAMES:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        with tqdm(total=len(files)) as pbar:
            for start in range(0, len(files), args.batch_size):
                image_ids = files[start:start + args.batch_size]
                images = _load_images(root, image_ids)
                # Advance by the real batch length so the final, possibly
                # shorter, batch does not push the bar past its total.
                pbar.update(len(image_ids))
                try:
                    inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
                except Exception as err:
                    # Best effort: report the offending batch and give up on
                    # this model, then carry on with the next one.
                    print(image_ids)
                    print(f"preprocessing failed for the batch above: {err!r}")
                    break
                # Moves axis 1 of pixel_values to the last position
                # (presumably channels-first -> channels-last; confirm
                # against what the model expects).
                inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
                features = model(**inputs).image_embeds
                # Appending after every batch keeps finished work on disk
                # even if a later batch fails.
                _append_rows(out_file, image_ids, features)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # type=int: without it a value given on the command line arrives as str
    # and breaks the range() step in main().
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--image_path", default="images")
    parser.add_argument("--out_path", default="features")
    args = parser.parse_args()
    main(args)