import os

import numpy as np
import pandas as pd
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from modeling_hybrid_clip import FlaxHybridCLIP


def main():
    # Load the Flax hybrid CLIP checkpoint trained on the ROCO radiology dataset.
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")

    # The vision tower follows CLIP ViT-B/32 preprocessing, so the stock
    # CLIPProcessor handles resizing and normalization of the input images.
    vision_model_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(vision_model_name)

    img_dir = "/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images"
    img_list = os.listdir(img_dir)

    embeddings = []
    for idx, img_path in enumerate(img_list):
        if idx % 10 == 0:
            print(f"{idx} images processed")
        img = Image.open(os.path.join(img_dir, img_path)).convert('RGB')
        inputs = processor(images=img, return_tensors="jax")
        # Flax convolutions expect channels-last input, so move NCHW -> NHWC.
        inputs['pixel_values'] = inputs['pixel_values'].transpose(0, 2, 3, 1)
        img_vec = model.get_image_features(**inputs)
        embeddings.append(np.array(img_vec).reshape(-1).tolist())

    # Return the filenames alongside the embeddings so callers can persist
    # or query them instead of the results being discarded.
    return img_list, embeddings


if __name__ == '__main__':
    main()
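
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): two ways the embeddings
# collected by main() might be used, motivated by the pandas and AutoTokenizer
# imports in the original file. Assumptions are flagged inline; in particular,
# loading a tokenizer directly from "flax-community/medclip-roco" and the
# exact get_text_features signature follow the hybrid CLIP research example
# and are not verified against this checkpoint.

def save_embeddings(img_list, embeddings, out_path="embeddings.csv"):
    # Persist one row per image, indexed by filename, so the index can be
    # rebuilt without re-encoding the whole directory. `out_path` is a
    # hypothetical name chosen for this sketch.
    pd.DataFrame(embeddings, index=img_list).to_csv(out_path)


def retrieve(model, tokenizer, query, img_list, embeddings, top_k=5):
    # Rank images by cosine similarity between the query's text embedding
    # and the precomputed image embeddings.
    tokens = tokenizer([query], return_tensors="jax", padding=True)
    text_vec = np.array(
        model.get_text_features(tokens["input_ids"], tokens["attention_mask"])
    ).reshape(-1)
    text_vec /= np.linalg.norm(text_vec)

    img_mat = np.asarray(embeddings, dtype=np.float32)
    img_mat /= np.linalg.norm(img_mat, axis=1, keepdims=True)

    scores = img_mat @ text_vec
    best = np.argsort(-scores)[:top_k]
    return [(img_list[i], float(scores[i])) for i in best]


# Example wiring (assumes the checkpoint repo also hosts its text tokenizer):
#
#   img_list, embeddings = main()
#   save_embeddings(img_list, embeddings)
#   model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
#   tokenizer = AutoTokenizer.from_pretrained("flax-community/medclip-roco")
#   print(retrieve(model, tokenizer, "chest x-ray", img_list, embeddings))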