import torch
import os
from transformers import AutoTokenizer
from jax import numpy as jnp
import json
import requests
import zipfile
import io
import natsort
from PIL import Image as PilImage
from tqdm import tqdm


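class CustomDataSet(torch.utils.data.Dataset):
    """Dataset over every entry in `main_dir`, each opened as an RGB image."""

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        # Natural sort ("img2" before "img10") gives a stable index -> filename
        # mapping, which find_image relies on through get_image_name.
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)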

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = PilImage.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image


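def text_encoder(text, model, tokenizer):
    """Encode one query string into an L2-normalized embedding of shape (1, dim)."""
    # Pad or truncate the query to a fixed length of 96 tokens.
    inputs = tokenizer(
        [text],
        max_length=96,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    # The batch holds a single query, so keep only its embedding.
    embedding = model.get_text_features(
        inputs["input_ids"], inputs["attention_mask"]
    )[0]
    # Scale to unit length so dot products with image features are cosine similarities.
    embedding = embedding / jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)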


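def precompute_image_features(model, loader):
    """Embed every image batch from `loader`; return an (N, dim) array of unit rows.

    Row i corresponds to dataset index i only if `loader` preserves dataset
    order (e.g. a DataLoader with the default shuffle=False).
    """
    image_features = []
    for images in tqdm(loader):
        # Torch batches are (N, C, H, W); this code passes (N, H, W, C) NumPy arrays to the model.
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)
        # L2-normalize each row so dot products become cosine similarities.
        features = features / jnp.linalg.norm(features, axis=-1, keepdims=True)
        image_features.append(features)
    # Stack the per-batch arrays into a single (N, dim) array.
    return jnp.concatenate(image_features, axis=0)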


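def find_image(text_query, model, dataset, tokenizer, image_features, n=1):
    """Return paths of the n images most similar to `text_query`, best match first.

    Paths use a hard-coded "photos/" prefix, so `dataset` is assumed to have
    been built from that directory.
    """
    # text_encoder already returns a unit-length embedding.
    text_embedding = text_encoder(text_query, model, tokenizer)
    # Cosine similarity of every image with the query: shape (N,).
    similarities = jnp.dot(image_features, text_embedding.reshape(-1))
    # Sort once, highest similarity first, and keep the top n indices.
    top_indices = jnp.argsort(similarities)[::-1][:n]
    return ["photos/" + dataset.get_image_name(int(idx)) for idx in top_indices]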