fix curse issue

app.py CHANGED
@@ -4,7 +4,6 @@ import time
 import random
 import torch
 import torchvision.transforms as transforms
-#import requests
 import gradio as gr
 import matplotlib.pyplot as plt
 
@@ -12,16 +11,13 @@ from models import get_model
 from dotmap import DotMap
 from PIL import Image
 
-
-os.environ['TERM'] = 'linux'
-os.environ['TERMINFO'] = '/etc/terminfo'
-
+#os.environ['TERM'] = 'linux'
+#os.environ['TERMINFO'] = '/etc/terminfo'
 
 # args
 args = DotMap()
 args.deploy = 'vanilla'
 args.arch = 'dino_small_patch16'
-#args.resume = '/fast_scratch/hushell/fluidstack/FS125_few-shot-transformer/outputs/dinosmall_1e-4/best_converted.pth'
 args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
 args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
 args.cx = '06d75168141bc47f1'
@@ -31,7 +27,6 @@ args.cx = '06d75168141bc47f1'
 device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = get_model(args)
 model.to(device)
-#checkpoint = torch.load(args.resume, map_location='cpu')
 checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
 model.load_state_dict(checkpoint['model'], strict=True)
 
@@ -63,6 +58,12 @@ def denormalize(x, mean, std):
 # Google image search
 from google_images_search import GoogleImagesSearch
 
+class MyGIS(GoogleImagesSearch):
+    def __enter__(self):
+        return self
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return
+
 # define search params
 # option for commonly used search param are shown below for easy reference.
 # For param marked with '##':
@@ -90,7 +91,6 @@ def inference(query, labels, n_supp=10):
     labels = labels.split(',')
     n_supp = int(n_supp)
 
-    #print(f'#rows={len(labels)}, #cols={n_supp}')
     fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
 
     with torch.no_grad():
@@ -102,26 +102,24 @@
 
         # search support images
        for idx, y in enumerate(labels):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                supp_x.append(x_im)
-                supp_y.append(idx)
+            gis = GoogleImagesSearch(args.api_key, args.cx)
+            _search_params['q'] = y
+            _search_params['num'] = n_supp
+            gis.search(search_params=_search_params, custom_image_name='my_image')
+            gis._custom_image_name = 'my_image'
+
+            for j, x in enumerate(gis.results()):
+                x.download('./')
+                x_im = Image.open(x.path)
+
+                # vis
+                axs[idx, j].imshow(x_im)
+                axs[idx, j].set_title(f'{y}{j}')
+                axs[idx, j].axis('off')
+
+                x_im = preprocess(x_im) # (3, H, W)
+                supp_x.append(x_im)
+                supp_y.append(idx)
 
         print('Searching for support images is done.')
 
@@ -148,7 +146,6 @@ gr.Interface(fn=inference,
             inputs=[
                 gr.inputs.Image(label="Image to classify", type="pil"),
                 gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
-                #gr.inputs.Number(default=1, label="Number of support examples from Google")
                 gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
             ],
             theme="grass",
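
For context, the MyGIS subclass added above presumably exists so that GoogleImagesSearch can be used in a with statement (the library class does not implement the context-manager protocol itself); its call site is outside the hunks shown here, so the code below is only a minimal usage sketch under that assumption, with placeholder credentials:

# Hypothetical usage sketch, not part of the commit.
from google_images_search import GoogleImagesSearch

class MyGIS(GoogleImagesSearch):
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        return

API_KEY = 'YOUR_GOOGLE_API_KEY'        # placeholder; app.py hard-codes its own key
CX = 'YOUR_CUSTOM_SEARCH_ENGINE_ID'    # placeholder; app.py hard-codes its own cx

_search_params = {'q': 'golden retriever', 'num': 3}   # query text and result count

with MyGIS(API_KEY, CX) as gis:
    gis.search(search_params=_search_params)
    for image in gis.results():
        image.download('./')           # save the hit to the current directory
        print(image.path)              # local path of the downloaded file

The other substantive change, loading the checkpoint with torch.hub.load_state_dict_from_url instead of torch.load on a local path, fetches the file from the Hugging Face URL on first use and caches it in the torch hub checkpoint cache; map_location='cpu' keeps the weights on CPU, matching device = 'cpu' in app.py.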