import argparse
import csv
import os

import jax.numpy as jnp
from jax import jit
from PIL import Image
from tqdm import tqdm

from utils import load_model


def main(args):
    """Extract image embeddings with each KoCLIP model and append them to TSV files.

    Every file in ``args.image_path`` must have a ``.jpg`` suffix. For each of
    the models ``koclip/koclip-base`` and ``koclip/koclip-large`` the images
    are processed in batches of ``args.batch_size`` and one row per image is
    appended to ``<args.out_path>/<model_name>.tsv``. A row holds the image
    file name and the comma-joined embedding values, separated by a tab.

    Args:
        args: Namespace-like object exposing ``image_path``, ``out_path`` and
            ``batch_size`` (an int, or a string that parses as one).

    Raises:
        ValueError: If ``batch_size`` is not positive or if the image
            directory contains a file that does not end in ``.jpg``.
    """
    root = args.image_path
    # Coerce defensively: callers may hand over a string taken from the CLI.
    batch_size = int(args.batch_size)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    files = os.listdir(root)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    non_jpg = [name for name in files if not name.endswith(".jpg")]
    if non_jpg:
        raise ValueError(f"Non-.jpg files found in {root!r}: {non_jpg}")

    # Create the output directory up front so a missing directory does not
    # surface only after the first batch has already been computed.
    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ["koclip-base", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        with tqdm(total=len(files)) as pbar:
            for start in range(0, len(files), batch_size):
                image_ids = files[start:start + batch_size]
                images = []
                for file_name in image_ids:
                    # The context manager closes the underlying file handle;
                    # convert() returns a separate, fully loaded image.
                    with Image.open(os.path.join(root, file_name)) as raw:
                        images.append(raw.convert("RGB"))
                # Advance by the real batch length so the final, possibly
                # shorter, batch does not push the bar past its total.
                pbar.update(len(images))

                try:
                    inputs = processor(
                        text=[""], images=images, return_tensors="jax", padding=True
                    )
                except Exception as err:
                    # Best effort: report the offending batch and stop
                    # processing the current model.
                    print(f"Failed to preprocess batch {image_ids}: {err!r}")
                    break

                # Move axis 1 to the end of pixel_values.
                # NOTE(review): presumably channels-first -> channels-last for
                # the model; confirm against load_model's expectations.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                features = model(**inputs).image_embeds

                # Append per batch so completed batches are already on disk
                # if a later batch fails.
                with open(out_file, "a", newline="") as tsv_file:
                    writer = csv.writer(tsv_file, delimiter="\t")
                    for image_id, feature in zip(image_ids, features):
                        writer.writerow([image_id, ",".join(map(str, feature))])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # type=int so a value given on the command line is not left as a string.
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--image_path", default="images")
    parser.add_argument("--out_path", default="features")
    args = parser.parse_args()
    main(args)