import os

import numpy as np
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from modeling_hybrid_clip import FlaxHybridCLIP


def main():
    # Load the pretrained hybrid CLIP checkpoint and the CLIP image processor that
    # matches its vision tower.
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    vision_model_name = "openai/clip-vit-base-patch32"
    img_dir = "/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images"

    processor = CLIPProcessor.from_pretrained(vision_model_name)

    img_list = os.listdir(img_dir)
    embeddings = []

    for idx, img_path in enumerate(img_list):
        if idx % 10 == 0:
            print(f"{idx} images processed")
        # Preprocess the image and move channels last (NCHW -> NHWC) for the Flax vision model.
        img = Image.open(os.path.join(img_dir, img_path)).convert("RGB")
        inputs = processor(images=img, return_tensors="jax")
        inputs["pixel_values"] = inputs["pixel_values"].transpose(0, 2, 3, 1)
        # Encode the image and keep its embedding as a flat list of floats.
        img_vec = model.get_image_features(**inputs)
        img_vec = np.array(img_vec).reshape(-1).tolist()
        embeddings.append(img_vec)

    return img_list, embeddings
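

# Illustrative sketch, not part of the original script: one way the image embeddings
# returned by main() could be queried with a free-text prompt via the checkpoint's text
# tower. The function, the AutoTokenizer choice, and the cosine-similarity ranking are
# assumptions for illustration; adjust them to the checkpoint's actual text encoder.
def retrieve_images(query, model, img_list, embeddings, top_k=5):
    # Assumed: the checkpoint ships a tokenizer loadable with AutoTokenizer.
    tokenizer = AutoTokenizer.from_pretrained("flax-community/medclip-roco")
    tokens = tokenizer(query, return_tensors="jax", padding=True)
    text_vec = np.array(
        model.get_text_features(
            input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"]
        )
    ).reshape(-1)

    # Rank images by cosine similarity between the query embedding and each image embedding.
    img_mat = np.array(embeddings)
    sims = img_mat @ text_vec / (
        np.linalg.norm(img_mat, axis=1) * np.linalg.norm(text_vec) + 1e-8
    )
    best = np.argsort(-sims)[:top_k]
    return [(img_list[i], float(sims[i])) for i in best]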
        

if __name__ == "__main__":
    main()