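# Draw to Search: a Gradio sketchpad app. A small CNN classifies the drawing,
# the predicted class name is embedded with CLIP, and the movie frames whose
# precomputed CLIP image embeddings are most similar are downloaded and shown.
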
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from torch import nn
from transformers import CLIPModel, CLIPProcessor

LABELS = Path('class_names.txt').read_text().splitlines()
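# Small CNN sketch classifier. The Flatten -> Linear(1152, 256) sizing implies a
# 28x28 grayscale input (28 -> 14 -> 7 -> 3 after three 2x2 max-pools; 128 * 3 * 3 = 1152).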
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
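# Load the trained weights onto the CPU and switch to inference mode
# (strict=False tolerates extra or missing keys in the checkpoint).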
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()


# CLIP embeds the predicted class name as text; the movie-frame catalogue's image URLs
# (data2.csv, column 'path') are assumed row-aligned with the precomputed image
# embeddings in embeddings.npy.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
df = pd.read_csv('data2.csv')
embeddings_npy = np.load('embeddings.npy')
# L2-normalise the image embeddings so the dot product against a text embedding ranks by cosine similarity.
embeddings = embeddings_npy / np.linalg.norm(embeddings_npy, axis=1, keepdims=True)

def compute_text_embeddings(list_of_strings):
    # Tokenise the strings and return their CLIP text embeddings (inference only, so no grad).
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    with torch.no_grad():
        return model.get_text_features(**inputs)

def download_img(path):
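    # Fetch a result image by URL and save it next to the script so Gradio can serve it as a file.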
    img_data = requests.get(path).content
    local_path = path.split("/")[-1]
    with open(local_path, 'wb') as handler:
        handler.write(img_data)
    return local_path
    
def predict(im):
    # Scale the sketch to [0, 1] and add batch and channel dimensions: (1, 1, H, W).
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    values, indices = torch.topk(probabilities, 5)

    # Use the top-1 predicted class name as the CLIP text query.
    query = LABELS[indices[0]]

    # Rank the catalogue by similarity to the text embedding and keep the top n_results,
    # most similar first.
    n_results = 3
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    results = np.argsort((embeddings @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['path']) for i in results]
    # Prepend the class probabilities so the Label output comes first, followed by the image paths.
    outputs.insert(0, {LABELS[i]: v.item() for i, v in zip(indices, values)})
    print(outputs)
    return outputs

title = "Draw to Search"
description = "Using the power of CLIP and a small CNN, find images from movies based on what you draw!"

# Output order matches predict's return value: the class probabilities, then three result images.
iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    description=description,
    live=True,
)
iface.launch(debug=True)
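# debug=True surfaces errors in the console; share=True could also be passed here for a public link.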