import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    """Loads images from a directory and preprocesses them into CLIP pixel tensors."""

    def __init__(self, image_dir, processor):
        # Sort for a deterministic ordering and match extensions case-insensitively.
        self.image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale and RGBA files yield 3-channel inputs.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]

def get_clip_embeddings_batch(image_dir, batch_size=32, device=None):
    # Fall back to CPU when no GPU is available.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Create dataset and dataloader; images are preprocessed in worker processes
    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            # get_image_features returns the projected (unnormalized) image embeddings.
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())

    return np.concatenate(all_embeddings, axis=0)

def create_embeddings(image_dir, batch_size, device):
    # Compute embeddings for all images and cache them to image_embeddings.npy
    # in the current working directory.
    embeddings = get_clip_embeddings_batch(image_dir, batch_size, device)
    np.save("image_embeddings.npy", embeddings)
    return embeddings
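

# Example usage: a minimal sketch, assuming images live in a placeholder
# "./images" directory. The expected embedding width of 512 comes from the
# clip-vit-base-patch32 projection head. The __main__ guard matters because
# the DataLoader spawns worker processes.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = create_embeddings("./images", batch_size=32, device=device)
    print(embeddings.shape)  # expected: (num_images, 512)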