osanseviero (HF staff) committed
Commit 7a990e9 • 1 Parent(s): 98b23d4

Update app.py
Files changed (1): app.py  +35 -7
app.py CHANGED
@@ -1,11 +1,32 @@
-import pandas as pd, numpy as np
 import os
-from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
 
+from pathlib import Path
+import pandas as pd, numpy as np
+from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
+import torch
+from torch import nn
 import gradio as gr
 import requests
 
-
+LABELS = Path('class_names.txt').read_text().splitlines()
+class_model = nn.Sequential(
+    nn.Conv2d(1, 32, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Conv2d(32, 64, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Conv2d(64, 128, 3, padding='same'),
+    nn.ReLU(),
+    nn.MaxPool2d(2),
+    nn.Flatten(),
+    nn.Linear(1152, 256),
+    nn.ReLU(),
+    nn.Linear(256, len(LABELS)),
+)
+state_dict = torch.load('pytorch_model.bin', map_location='cpu')
+class_model.load_state_dict(state_dict, strict=False)
+class_model.eval()
 
 
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
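The classifier head added above only lines up dimensionally if it is fed a single-channel 28x28 sketch (the legacy Gradio sketchpad default): three 2x2 max-pools take 28 -> 14 -> 7 -> 3, and 128 * 3 * 3 = 1152, which is what nn.Linear(1152, 256) expects. A minimal shape check of that arithmetic, assuming the 28x28 input:

import torch
from torch import nn

# Mirror of the convolutional part of class_model above; the 28x28 input size is an assumption.
features = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
)
with torch.no_grad():
    print(features(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 1152])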
 
@@ -26,20 +47,27 @@ def download_img(path):
     return local_path
 
 def predict(query):
+    x = torch.tensor(query, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
+    with torch.no_grad():
+        out = class_model(x)
+    probabilities = torch.nn.functional.softmax(out[0], dim=0)
+    values, indices = torch.topk(probabilities, 5)
+
+    query = LABELS[indices[0]]
+
     n_results=3
     text_embeddings = compute_text_embeddings([query]).detach().numpy()
     results = np.argsort((embeddings@text_embeddings.T)[:, 0])[-1:-n_results-1:-1]
     paths = [download_img(df.iloc[i]['path']) for i in results]
     print(paths)
-    return paths
+    return {LABELS[i]: v.item() for i, v in zip(indices, values)}, *paths
 
 title = "Draw to Search"
 iface = gr.Interface(
     fn=predict,
-    inputs=[gr.inputs.Textbox(label="text", lines=3)],
-    outputs=[gr.outputs.Image(type="file"), gr.outputs.Image(type="file"), gr.outputs.Image(type="file")],
+    inputs='sketchpad',
+    outputs=['label', gr.outputs.Image(type="file"), gr.outputs.Image(type="file"), gr.outputs.Image(type="file")],
     title=title,
-    examples=[["Sunset"]]
 )
 iface.launch(debug=True)
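The retrieval lines predict keeps from the previous version score every precomputed image embedding by dot product against the text embedding, and the slice [-1:-n_results-1:-1] reads off the top n_results indices, best first. A self-contained toy run of that ranking step (the arrays below are made up, not the app's real embeddings or compute_text_embeddings output):

import numpy as np

embeddings = np.array([[0.1, 0.9, 0.0],    # stand-in for the precomputed image embeddings
                       [0.8, 0.1, 0.1],
                       [0.2, 0.2, 0.6],
                       [0.9, 0.0, 0.1]])
text_embeddings = np.array([[1.0, 0.0, 0.0]])  # shape (1, d), like compute_text_embeddings([query])

n_results = 3
scores = (embeddings @ text_embeddings.T)[:, 0]      # one similarity score per image
results = np.argsort(scores)[-1:-n_results-1:-1]     # top-3 indices, highest score first
print(results)  # [3 1 2]

In app.py, df.iloc[i]['path'] then resolves each winning index back to an image path for download_img.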