# koclip/executables/embed_images.py
# Last commit 8259f2d by jaketae: "chore: mv executables to separate directory"
# (Header recovered from the Hugging Face file view: 2.01 kB, "No virus" scan badge.)
import argparse
import csv
import os
import jax.numpy as jnp
from jax import jit
from PIL import Image
from tqdm import tqdm
from utils import load_model
def _list_images(root):
    """Return the sorted entry names of *root*, requiring every one to end in ``.jpg``.

    Raises:
        ValueError: if any entry of *root* does not end with ``.jpg``.
    """
    # Sorted so the output row order is reproducible; os.listdir order is arbitrary.
    names = sorted(os.listdir(root))
    bad = [name for name in names if not name.endswith(".jpg")]
    if bad:
        # An explicit raise, unlike ``assert``, still runs under ``python -O``.
        raise ValueError(f"non-.jpg entries in {root!r}: {bad}")
    return names


def _load_rgb_images(root, names):
    """Open each file in *names* (relative to *root*) and return RGB copies."""
    images = []
    for name in names:
        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory image, so it outlives the handle.
        with Image.open(os.path.join(root, name)) as img:
            images.append(img.convert("RGB"))
    return images


def _append_features(path, image_ids, features):
    """Append one ``<image_id>\\t<comma-joined embedding>`` row per image to *path*."""
    # newline="" is what the csv module requires for files it writes to;
    # without it extra blank lines appear on platforms that translate "\n".
    with open(path, "a", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh, delimiter="\t")
        writer.writerows(
            [image_id, ",".join(map(str, feature))]
            for image_id, feature in zip(image_ids, features)
        )


def main(args):
    """Embed every ``.jpg`` in ``args.image_path`` with each KoCLIP model.

    For each of ``koclip-base`` and ``koclip-large``, the image embeddings
    are appended to ``<args.out_path>/<model_name>.tsv``. Each row holds the
    image file name, a tab, then the embedding values joined by commas.

    Args:
        args: namespace with ``image_path`` (directory of images),
            ``batch_size`` (int, or a string convertible to int) and
            ``out_path`` (output directory, created if missing).

    Raises:
        ValueError: if ``image_path`` contains a non-``.jpg`` entry or
            ``batch_size`` is not a positive integer.
    """
    root = args.image_path
    files = _list_images(root)
    # argparse yields a str for options given on the command line unless
    # the option declares type=int, so coerce defensively here.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ["koclip-base", "koclip-large"]:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        with tqdm(total=len(files)) as pbar:
            for start in range(0, len(files), batch_size):
                image_ids = files[start : start + batch_size]
                images = _load_rgb_images(root, image_ids)
                # Advance by the real batch length so the final, shorter
                # batch does not push the bar past its total.
                pbar.update(len(image_ids))
                try:
                    inputs = processor(
                        text=[""], images=images, return_tensors="jax", padding=True
                    )
                except Exception as err:
                    # Report the offending batch and skip it, instead of
                    # abandoning every remaining batch for this model.
                    # tqdm.write keeps the message from garbling the bar.
                    tqdm.write(f"{model_name}: skipping batch {image_ids}: {err!r}")
                    continue
                # Move axis 1 to the end (presumably channels-first ->
                # channels-last for the Flax model; confirm against utils).
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                features = model(**inputs).image_embeds
                # Written per batch so partial progress survives a crash.
                _append_features(out_file, image_ids, features)
def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``batch_size`` (int), ``image_path`` and
        ``out_path``.
    """
    parser = argparse.ArgumentParser()
    # type=int: without it, a value passed on the command line stays a str
    # and breaks the range() stepping in main().
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="number of images embedded per model call",
    )
    parser.add_argument(
        "--image_path",
        default="images",
        help="directory containing the .jpg images to embed",
    )
    parser.add_argument(
        "--out_path",
        default="features",
        help="directory the <model_name>.tsv feature files are appended to",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())