import logging
import os

import clip
import pandas as pd
import torch
from PIL import Image

class SearchEngineModel:
    """CLIP-based image search engine.

    Encodes images and text/image prompts with CLIP (ViT-B/32) and ranks the
    indexed images by similarity to the prompt.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        model, preprocess = clip.load("ViT-B/32", device=self.device)

        return model, preprocess

    def read_image(self, image_path):
        pil_image = Image.open(image_path)

        return pil_image

    def encode_image(self, model, preprocess, image_path):
        # Preprocess a single image and run it through the CLIP image encoder.
        image = preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = model.encode_image(image)

        # Move the features back to the CPU before converting to a DataFrame;
        # calling .numpy() directly on a CUDA tensor raises an error.
        image_features = pd.DataFrame(image_features.cpu().numpy())

        return image_features

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        # Encode the full image collection, then rank the images by dot-product
        # similarity with the prompt features and keep the top matches.
        encoded_images, image_paths = self.encode_images(self.model, self.preprocess, self.image_root_dir, self.csv_file_path)
        similarity = encoded_images @ prompt_features.T
        values, indices = similarity.topk(nofimages_to_show, dim=0)

        results = []
        for value, index in zip(values, indices):
            # topk returns tensors, so cast each index to a plain int before
            # using it to index the list of paths.
            results.append(image_paths[int(index)])

        return results
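
    # NOTE: __search_image_auxiliar_func__ relies on an encode_images method
    # and on image_root_dir / csv_file_path attributes that are not defined in
    # this file. The method below is a minimal sketch of what such a helper
    # might look like; the CSV layout (a single "image_path" column with paths
    # relative to image_root_dir) is an assumption for illustration only.
    def encode_images(self, model, preprocess, image_root_dir, csv_file_path):
        # Collect the image paths listed in the (assumed) CSV index.
        image_paths = [
            os.path.join(image_root_dir, name)
            for name in pd.read_csv(csv_file_path)["image_path"]
        ]

        # Encode every image with the CLIP image encoder and stack the results
        # into a single (num_images, feature_dim) tensor.
        features = []
        with torch.no_grad():
            for path in image_paths:
                image = preprocess(Image.open(path)).unsqueeze(0).to(self.device)
                features.append(model.encode_image(image))

        return torch.cat(features, dim=0), image_paths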

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        # Tokenize the text prompt and encode it with the CLIP text encoder.
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)

        search_results = self.__search_image_auxiliar_func__(text_features, nofimages_to_show)

        return search_results

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        # Encode the query image (a PIL image) with the CLIP image encoder.
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)

        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)

        return search_results
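

# A minimal usage sketch, not part of the original file. It assumes the
# image_root_dir and csv_file_path attributes expected by the search methods
# have been set on the instance, and that the folder, CSV index, and
# "query.jpg" below exist; all of these names are hypothetical.
if __name__ == "__main__":
    engine = SearchEngineModel()
    engine.image_root_dir = "images"           # hypothetical image folder
    engine.csv_file_path = "images/index.csv"  # hypothetical CSV of image paths

    # Text-based search: return the 5 images most similar to the prompt.
    print(engine.search_image_by_text_prompt("a dog playing in the snow", 5))

    # Image-based search: return the 3 images most similar to a query image.
    print(engine.search_image_by_image_prompt(engine.read_image("query.jpg"), 3))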