koclip / embed_images.py
jaketae's picture
style: run linter
2cf3514
raw
history blame
2.01 kB
import argparse
import csv
import os
import jax.numpy as jnp
from jax import jit
from PIL import Image
from tqdm import tqdm
from utils import load_model
def main(args):
    """Embed every image under ``args.image_path`` with each KoCLIP model.

    For each of ``koclip-base`` and ``koclip-large`` the images are processed
    in batches, and one row per image is appended to
    ``<args.out_path>/<model_name>.tsv``: the file name, a tab, then the
    comma-separated components of that image's embedding.  Rows are appended,
    so re-running against the same output directory adds to existing files.

    Args:
        args: Namespace with ``image_path`` (directory whose entries must all
            end in ``.jpg``), ``out_path`` (output directory, created when
            missing) and ``batch_size`` (positive integer, or a string
            holding one).

    Raises:
        ValueError: If ``batch_size`` is not positive, or a directory entry
            does not end in ``.jpg``.
    """
    root = args.image_path

    # argparse delivers strings unless a type is declared; range() needs int.
    batch_size = int(args.batch_size)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    files = os.listdir(root)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    unexpected = [name for name in files if not name.endswith(".jpg")]
    if unexpected:
        raise ValueError(f"non-.jpg entries found in {root!r}: {unexpected}")

    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ["koclip-base", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        with tqdm(total=len(files)) as pbar:
            for start in range(0, len(files), batch_size):
                # Slicing clamps at the end of the list, so the last batch
                # may be shorter than batch_size.
                image_ids = files[start : start + batch_size]
                images = []
                for name in image_ids:
                    # convert() returns an independent image, so the source
                    # file can be closed right away.
                    with Image.open(os.path.join(root, name)) as img:
                        images.append(img.convert("RGB"))

                # Advance by the real batch length so the bar ends at 100%.
                pbar.update(len(image_ids))
                try:
                    inputs = processor(
                        text=[""], images=images, return_tensors="jax", padding=True
                    )
                except Exception as err:
                    # Best effort: report the offending batch and give up on
                    # this model, without swallowing KeyboardInterrupt.
                    print(image_ids)
                    print(f"preprocessing failed for {model_name}: {err!r}")
                    break

                # Move axis 1 to the end; presumably channels-first ->
                # channels-last for the model -- TODO confirm.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                features = model(**inputs).image_embeds

                # newline="" is what the csv module expects of its file object.
                with open(out_file, "a", newline="") as tsv_file:
                    writer = csv.writer(tsv_file, delimiter="\t")
                    for image_id, feature in zip(image_ids, features):
                        writer.writerow([image_id, ",".join(map(str, feature))])
def build_parser():
    """Return the command-line parser for the image-embedding script."""
    parser = argparse.ArgumentParser()
    # type=int: without it, a value given on the command line arrives as a
    # str and cannot be used as a range() step.
    parser.add_argument(
        "--batch_size", type=int, default=16, help="number of images per batch"
    )
    parser.add_argument(
        "--image_path", default="images", help="directory of .jpg images to embed"
    )
    parser.add_argument(
        "--out_path",
        default="features",
        help="directory the per-model .tsv files are written to",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)