"""Extract image embeddings with KoCLIP models and write them to TSV files.

For every model name listed in ``main`` the script loads the model through
``utils.load_model``, runs all files found in ``--image_path`` through it in
batches, and writes one ``<image file name>\\t<comma-separated embedding>``
row per image to ``<out_path>/<model_name>.tsv``.
"""

import argparse
import csv
import os

from PIL import Image

from utils import load_model


def _batches(items, size):
    """Yield consecutive slices of *items* holding at most *size* elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def _embed_batch(model, processor, root, file_names):
    """Return one embedding (a list of numbers) per file in *file_names*.

    The images are opened from *root*, passed through *processor* and
    *model* exactly as a single batch, and closed again afterwards so that
    file handles do not accumulate across batches.
    """
    images = []
    try:
        for file_name in file_names:
            images.append(Image.open(os.path.join(root, file_name)))
        # The processor is called with a dummy text input; only the image
        # embeddings of the model output are used.
        inputs = processor(
            text=[""], images=images, return_tensors="jax", padding=True
        )
        features = model(**inputs).image_embeds
        # assumes image_embeds is an array type exposing ``tolist()``
        # (e.g. a jax/numpy array) -- TODO confirm against load_model.
        return features.tolist()
    finally:
        for image in images:
            image.close()


def main(args):
    """Compute and store image embeddings for every configured model.

    Args:
        args: Namespace with ``image_path`` (directory holding the images),
            ``out_path`` (directory receiving the TSV files) and
            ``batch_size`` (number of images per forward pass).

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    root = args.image_path
    # os.listdir returns entries in arbitrary order; sort so that the rows
    # of the output files are reproducible between runs.
    files = sorted(os.listdir(root))

    for model_name in ["koclip", "koclip/koclip-large"]:
        # NOTE(review): the second entry already carries a "koclip/" prefix,
        # so the identifier passed below is "koclip/koclip/koclip-large" --
        # confirm this is what load_model expects.
        model, processor = load_model(f"koclip/{model_name}")

        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # model_name may contain "/", so the parent directory can be nested.
        os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)

        # Open the file once per model so that later batches extend the
        # output instead of overwriting the rows of earlier batches.
        with open(out_file, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter="\t")
            for batch in _batches(files, batch_size):
                features = _embed_batch(model, processor, root, batch)
                for image_id, feature in zip(batch, features):
                    writer.writerow([image_id, ",".join(map(str, feature))])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # type=int: without it a value given on the command line stays a string.
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--image_path", default="images")
    parser.add_argument("--out_path", default="features")
    args = parser.parse_args()
    main(args)