# bangla-clip / image_search.py
### search image
from glob import glob
import os

import albumentations as A
import cv2
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

import config as CFG
from CLIP_model import CLIPModel
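
# config.py is expected to provide at least:
#   CFG.size           -- input image resolution (images are resized to size x size)
#   CFG.text_tokenizer -- Hugging Face tokenizer name/path for the text encoder
# (Inferred from the usages below; check config.py for the actual values.)
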
def load_model(device, model_path):
    """Load the trained CLIP weights and the matching text tokenizer."""
    model = CLIPModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # inference only: disable dropout / use running batch-norm stats
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
    return model, tokenizer
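
# CLIPModel (from the local CLIP_model module) is assumed to expose
# get_image_embeddings(images) and get_text_embeddings(input_ids, attention_mask),
# each returning a (batch, embedding_dim) tensor; that is how it is used in
# search_images below.
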
def process_image(img_paths):
    """Load, resize, and normalize images into a single batch tensor."""
    # Build the transform once instead of re-creating it on every iteration.
    transforms_infer = A.Compose(
        [
            A.Resize(CFG.size, CFG.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )
    imgs = []
    for ip in img_paths:
        image = cv2.imread(ip)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
        image = transforms_infer(image=image)["image"]
        image = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW
        imgs.append(image)
    return torch.stack(imgs)
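
# Example (hypothetical paths): process_image(["a.jpg", "b.jpg"]) returns a float
# tensor of shape (2, 3, CFG.size, CFG.size), ready for get_image_embeddings.
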
def process_text(captions, tokenizer):
    """Tokenize captions into padded input-id and attention-mask tensors."""
    encoded = tokenizer(captions, padding=True)
    input_ids = torch.tensor(encoded["input_ids"]).long()
    mask = torch.tensor(encoded["attention_mask"]).long()
    return input_ids, mask
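
# Example: for two captions, input_ids and mask each have shape
# (2, max_token_len), padded to the longest caption in the batch.
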
def search_images(search_text, image_path, k=1):
    """Rank images under image_path by cosine similarity to the query text.

    Returns the top-k image paths, their similarity scores, and the raw
    top-k value/index tensors.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_model(device, model_path="clip_bangla.pt")
    extensions = ("*.jpg", "*.JPEG", "*.JPG", "*.png", "*.bmp")
    image_filenames = []
    for ext in extensions:
        image_filenames += glob(os.path.join(image_path, ext))
    print(f"Searching in image database >> {image_filenames}")
    imgs = process_image(image_filenames)
    if not isinstance(search_text, list):
        search_text = [search_text]
    input_ids, mask = process_text(search_text, tokenizer)
    with torch.no_grad():
        imgs = imgs.to(device)
        input_ids = input_ids.to(device)
        mask = mask.to(device)
        img_embeddings = model.get_image_embeddings(imgs)
        text_embeddings = model.get_text_embeddings(input_ids, mask)
    # L2-normalize both sides so the dot product equals cosine similarity.
    image_embeddings_n = F.normalize(img_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T  # (num_queries, num_images)
    top_k_vals, top_k_indices = torch.topk(dot_similarity.detach().cpu(), min(k, len(image_filenames)))
    top_k_vals = top_k_vals.flatten()
    top_k_indices = top_k_indices.flatten()
    ### log
    images_ret = []
    scores_ret = []
    for i in range(len(top_k_indices)):
        print(f"{image_filenames[int(top_k_indices[i])]} :: {top_k_vals[i]}")
        images_ret.append(image_filenames[int(top_k_indices[i])])
        scores_ret.append(float(top_k_vals[i]))
    return images_ret, scores_ret, top_k_vals, top_k_indices
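
# Example (hypothetical directory): unpack just the paths and scores, ignoring
# the raw tensors:
#   paths, scores, _, _ = search_images("query text", "my_images", k=3)
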
if __name__ == '__main__':
    # Demo assets and matching Bangla captions (English translations in the
    # comments). Only the first caption is searched below; the lists are kept
    # as a reference for the demo set.
    img_filenames = [
        "demo_images/1280px-Cox's_Bazar_Sunset.JPG",
        "demo_images/Cox's_Bazar,_BangladeshThe_sea_is_calm.jpg",
        "demo_images/Panta_Vaat_Hilsha_Fisha_VariousVarta_2012.JPG",
        "demo_images/Pohela_boishakh_10.jpg",
        "demo_images/Sundarban_Tiger.jpg",
    ]
    captions = [
        "সমুদ্র সৈকতের তীরে সূর্যাস্ত",  # sunset on the seashore
        "গাঢ় নীল সমুদ্র ও এক রাশি মেঘ",  # deep blue sea and a bank of clouds
        "পান্তা ভাত ইলিশ ও মজার খাবার",  # panta bhat, hilsa, and tasty food
        "এক দল মানুষ পহেলা বৈশাখে নাগর দোলায় চড়তে এসেছে",  # people riding a nagordola (ferris wheel) on Pohela Boishakh
        "সুন্দরবনের নদীর পাশে একটি বাঘ",  # a tiger beside a river in the Sundarbans
    ]
    search_images("সমুদ্র সৈকতের তীরে সূর্যাস্ত", "demo/", k=10)  # "sunset on the seashore"