import os

import cv2
import gradio as gr
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer

from implement import *
import config as CFG
from main import build_loaders
from CLIP import CLIPModel

# Route outbound requests (e.g. the tokenizer download) through an HTTP proxy.
os.environ["HTTPS_PROXY"] = "http://185.46.212.90:80/"
os.environ["HTTP_PROXY"] = "http://185.46.212.90:80/"
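
# Gradio demo: text-to-image search over the validation images with a trained CLIP model.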
with gr.Blocks(css="style.css") as demo:
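    # Load the trained CLIP model and precompute an embedding for every validation image.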
    def get_image_embeddings(valid_df, model_path):
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
        
        model = CLIPModel().to(CFG.device)
        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
        model.eval()
        
        valid_image_embeddings = []
        with torch.no_grad():
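            # Encode each batch of images and project it into the shared CLIP embedding space.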
            for batch in tqdm(valid_loader):
                image_features = model.image_encoder(batch["image"].to(CFG.device))
                image_embeddings = model.image_projection(image_features)
                valid_image_embeddings.append(image_embeddings)
        return model, torch.cat(valid_image_embeddings)

    # Build the validation split and embed its images once at startup ("best.pt" is the trained checkpoint).
    _, valid_df = make_train_valid_dfs()
    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

    def find_matches(query, n=9):
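        # Embed the text query and return the images whose CLIP embeddings score highest against it.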
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_query = tokenizer([query])
        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_query.items()
        }
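        # Encode the query with the text encoder and project it into the same space as the images.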
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)
        
        # Cosine similarity between the query and every cached image embedding.
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

        # Over-fetch (n * 5) candidates and keep every 5th index, which skips
        # repeated hits of the same image when it appears once per caption.
        _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
        matches = [valid_df['image'].values[idx] for idx in indices[::5]]
        
        # Read the matched images from disk and convert BGR -> RGB for display.
        images = []
        for match in matches:
            image = cv2.imread(f"{CFG.image_path}/{match}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)

        return images
    # UI: a query textbox, a gallery for the matched images, and a button wired to find_matches.
    with gr.Row():
        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
        gallery = gr.Gallery(label="Matching images")
    
    button = gr.Button("Search")
    button.click(
        fn=find_matches,
        inputs=textbox,
        outputs=gallery,
    )
    
# Launch the demo with a public share link.
demo.launch(share=True)