File size: 3,923 Bytes
a2b22c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417d400
a2b22c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22fe311
a2b22c2
22fe311
a2b22c2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
### search image
from multiprocessing import process
import torch
from transformers import AutoTokenizer
import config as CFG
from CLIP_model import CLIPModel
import cv2
import os
import torch
from glob import glob
import albumentations as A
import torch.nn.functional as F


def load_model(device, model_path):
    """Load the CLIP model weights and the matching text tokenizer.

    Args:
        device: Device the model is moved to; also used as ``map_location``
            when the checkpoint is deserialized.
        model_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is returned in evaluation
        mode; the tokenizer is the one named by ``CFG.text_tokenizer``.
    """
    model = CLIPModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    # This file only runs inference, so disable training-time behaviour
    # (e.g. dropout, batch-norm statistic updates) if the model has any;
    # leaving the module in train mode can make embeddings vary per call.
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
    return model, tokenizer

def process_image(img_path):
    """Read images from disk and stack them into one preprocessed batch.

    Each image is read with OpenCV, converted from BGR to RGB, resized to
    ``CFG.size`` x ``CFG.size``, normalized, and rearranged from
    height-width-channel to channel-first layout.

    Args:
        img_path: Iterable of image file paths.

    Returns:
        A float tensor holding one channel-first image per input path,
        stacked along a new leading dimension.

    Raises:
        OSError: If OpenCV cannot read one of the files.
        RuntimeError: Raised by ``torch.stack`` when ``img_path`` is empty.
    """
    # The pipeline does not depend on the individual image, so build it once
    # instead of once per file.
    transforms_infer = A.Compose(
        [
            A.Resize(CFG.size, CFG.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ])

    imgs = []
    for ip in img_path:
        image = cv2.imread(ip)
        if image is None:
            # cv2.imread signals failure by returning None; report the path
            # here rather than failing later inside cvtColor.
            raise OSError(f"Could not read image file: {ip}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = transforms_infer(image=image)['image']
        # HWC -> CHW
        image = torch.tensor(image).permute(2, 0, 1).float()
        print(image.shape)
        imgs.append(image)
    imgs = torch.stack(imgs)
    return imgs

def process_text(caption, tokenizer):
    """Tokenize caption(s) into padded id and attention-mask tensors.

    Args:
        caption: Text, or list of texts, handed to the tokenizer as-is.
        tokenizer: Callable invoked as ``tokenizer(caption, padding=True)``
            whose result exposes ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        A ``(e_text, mask)`` tuple of ``torch.int64`` tensors built from the
        tokenizer's ``input_ids`` and ``attention_mask`` respectively.
    """
    encoded = tokenizer(caption, padding=True)
    # Build integer tensors directly: going through torch.Tensor (float32)
    # first would round any id above 2**24.
    e_text = torch.tensor(encoded["input_ids"], dtype=torch.long)
    mask = torch.tensor(encoded["attention_mask"], dtype=torch.long)
    return e_text, mask

def search_images(search_text, image_path, k = 1, model_path="clip_bangla.pt"):
    """Rank the images in a folder by similarity to one or more text queries.

    Image and text embeddings are L2-normalized, so the score is the cosine
    similarity between a query and an image.

    Args:
        search_text: A query string or a list of query strings.
        image_path: Folder searched (non-recursively) for ``.jpg``, ``.JPEG``,
            ``.JPG``, ``.png`` and ``.bmp`` files.
        k: Maximum number of images returned per query; capped at the number
            of images found.
        model_path: Checkpoint passed to ``load_model``.

    Returns:
        ``(images_ret, scores_ret, top_k_vals, top_k_indices)``:
        file names and float scores as lists, followed by the flattened
        score and index tensors from ``torch.topk``. With several queries the
        per-query results are concatenated in query order. If no image is
        found, both lists and both tensors are empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_model(device, model_path=model_path)

    matches = []
    for pattern in ("/*.jpg", "/*.JPEG", "/*.JPG", "/*.png", "/*.bmp"):
        matches.extend(glob(image_path + pattern))
    # Where glob matching is case-insensitive, "*.jpg" and "*.JPG" hit the
    # same files; drop repeats while keeping the first-seen order.
    image_filenames = list(dict.fromkeys(matches))
    print(f"Searching in image database >> {image_filenames}")
    if not image_filenames:
        # Nothing to rank; stacking an empty batch would raise.
        return [], [], torch.empty(0), torch.empty(0, dtype=torch.long)
    imgs = process_image(image_filenames)

    if not isinstance(search_text, list):
        search_text = [search_text]
    e_text, mask = process_text(search_text, tokenizer)

    with torch.no_grad():
        imgs = imgs.to(device)
        e_text = e_text.to(device)
        mask = mask.to(device)
        img_embeddings = model.get_image_embeddings(imgs)
        text_embeddings = model.get_text_embeddings(e_text, mask)
        image_embeddings_n = F.normalize(img_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        # Rows are queries, columns are images.
        dot_similarity = text_embeddings_n @ image_embeddings_n.T
        print(dot_similarity.shape)
    top_k_vals, top_k_indices = torch.topk(dot_similarity.detach().cpu(), min(k, len(image_filenames)))
    top_k_vals = top_k_vals.flatten()
    top_k_indices = top_k_indices.flatten()
    print(top_k_indices)
    print(top_k_vals)
    ### log
    images_ret = []
    scores_ret = []
    for idx, score in zip(top_k_indices.tolist(), top_k_vals):
        print(f"{image_filenames[idx]} :: {score}")
        images_ret.append(image_filenames[idx])
        scores_ret.append(float(score))

    return images_ret, scores_ret, top_k_vals, top_k_indices

if __name__ == '__main__':
    # Sample image paths and their Bangla captions shipped with the demo.
    # They are kept for reference; the search call below does not read them.
    img_filenames = [
        "demo_images/1280px-Cox's_Bazar_Sunset.JPG",
        "demo_images/Cox's_Bazar,_BangladeshThe_sea_is_calm.jpg",
        "demo_images/Panta_Vaat_Hilsha_Fisha_VariousVarta_2012.JPG",
        "demo_images/Pohela_boishakh_10.jpg",
        "demo_images/Sundarban_Tiger.jpg",
    ]
    captions = [
        "সমুদ্র সৈকতের তীরে সূর্যাস্ত",
        "গাঢ় নীল সমুদ্র ও এক রাশি মেঘ",
        "পান্তা ভাত ইলিশ ও মজার খাবার",
        "এক দল মানুষ পহেলা বৈশাখে নাগর দোলায় চড়তে এসেছে",
        "সুন্দরবনের নদীর পাশে একটি বাঘ",
    ]

    # Run one query against every supported image in the demo folder.
    search_images("সমুদ্র সৈকতের তীরে সূর্যাস্ত", "demo_images/", k=10)