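# medclip-roco / tools/create_embeddings.py
# Last commit: "Add documentation" (aa31199), kaushalya.
# The remaining file-viewer page residue ("raw history blame", "No virus",
# "1.33 kB") is kept here as a comment so the module parses as Python.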
import os
import jax
from transformers import AutoTokenizer, CLIPProcessor
from configuration_hybrid_clip import HybridCLIPConfig
from modeling_hybrid_clip import FlaxHybridCLIP
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import Resize, Normalize, ConvertImageDtype, ToTensor
import numpy as np
import pandas as pd
def main(
    img_dir="/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images",
    model_name="flax-community/medclip-roco",
    vision_model_name="openai/clip-vit-base-patch32",
):
    """Compute an image embedding for every readable image in ``img_dir``.

    Each file is opened with PIL, converted to RGB, preprocessed with the
    ``CLIPProcessor`` loaded from ``vision_model_name`` and passed to
    ``FlaxHybridCLIP.get_image_features``.

    Args:
        img_dir: Directory whose entries are treated as image files.
        model_name: Identifier passed to ``FlaxHybridCLIP.from_pretrained``.
        vision_model_name: Identifier passed to
            ``CLIPProcessor.from_pretrained``.

    Returns:
        A dict mapping each successfully processed file name (relative to
        ``img_dir``) to its embedding, flattened to a plain list of floats.
        Entries are inserted in sorted file-name order. Nothing is written
        to disk; persisting the result is left to the caller.
    """
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(vision_model_name)

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # processing order (and therefore the output order) reproducible.
    img_names = sorted(os.listdir(img_dir))

    # Keyed by file name so every vector stays associated with its image
    # even when some directory entries are skipped below.
    embeddings = {}

    for idx, img_name in enumerate(img_names):
        if idx % 10 == 0:
            print(f"{idx} images processed")

        img_path = os.path.join(img_dir, img_name)
        try:
            # The context manager closes the underlying file; convert()
            # returns an independent in-memory image that outlives it.
            with Image.open(img_path) as img_file:
                img = img_file.convert('RGB')
        except OSError as err:
            # Directories, hidden files such as .DS_Store and corrupt images
            # all surface as OSError (PIL's UnidentifiedImageError is a
            # subclass). Skip them instead of aborting the whole run.
            print(f"Skipping {img_name}: {err}")
            continue

        inputs = processor(images=img, return_tensors="jax", padding=True)
        # Move axis 1 to the end. NOTE(review): presumably the processor emits
        # channel-first (N, C, H, W) and the Flax model wants channel-last
        # (N, H, W, C) -- confirm against modeling_hybrid_clip.
        inputs['pixel_values'] = inputs['pixel_values'].transpose(0, 2, 3, 1)

        img_vec = model.get_image_features(**inputs)
        # Flatten to a 1-D list of Python floats so the result is easy to
        # serialise (JSON, CSV, DataFrame, ...).
        embeddings[img_name] = np.array(img_vec).reshape(-1).tolist()

    return embeddings
# Run the embedding job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()