hushell commited on
Commit
1aa70ca
1 Parent(s): e337a53

fix curse issue

Browse files
Files changed (1) hide show
  1. app.py +26 -29
app.py CHANGED
@@ -4,7 +4,6 @@ import time
4
  import random
5
  import torch
6
  import torchvision.transforms as transforms
7
- #import requests
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
 
@@ -12,16 +11,13 @@ from models import get_model
12
  from dotmap import DotMap
13
  from PIL import Image
14
 
15
-
16
- os.environ['TERM'] = 'linux'
17
- os.environ['TERMINFO'] = '/etc/terminfo'
18
-
19
 
20
  # args
21
  args = DotMap()
22
  args.deploy = 'vanilla'
23
  args.arch = 'dino_small_patch16'
24
- #args.resume = '/fast_scratch/hushell/fluidstack/FS125_few-shot-transformer/outputs/dinosmall_1e-4/best_converted.pth'
25
  args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
26
  args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
27
  args.cx = '06d75168141bc47f1'
@@ -31,7 +27,6 @@ args.cx = '06d75168141bc47f1'
31
  device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  model = get_model(args)
33
  model.to(device)
34
- #checkpoint = torch.load(args.resume, map_location='cpu')
35
  checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
36
  model.load_state_dict(checkpoint['model'], strict=True)
37
 
@@ -63,6 +58,12 @@ def denormalize(x, mean, std):
63
  # Google image search
64
  from google_images_search import GoogleImagesSearch
65
 
 
 
 
 
 
 
66
  # define search params
67
  # option for commonly used search param are shown below for easy reference.
68
  # For param marked with '##':
@@ -90,7 +91,6 @@ def inference(query, labels, n_supp=10):
90
  labels = labels.split(',')
91
  n_supp = int(n_supp)
92
 
93
- #print(f'#rows={len(labels)}, #cols={n_supp}')
94
  fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
95
 
96
  with torch.no_grad():
@@ -102,26 +102,24 @@ def inference(query, labels, n_supp=10):
102
 
103
  # search support images
104
  for idx, y in enumerate(labels):
105
- with GoogleImagesSearch(args.api_key, args.cx) as gis:
106
- _search_params['q'] = y
107
- _search_params['num'] = n_supp
108
- gis.search(search_params=_search_params, custom_image_name='my_image')
109
- gis._custom_image_name = 'my_image'
110
-
111
- for j, x in enumerate(gis.results()):
112
- #url = x.url
113
- #x_im = Image.open(requests.get(url, stream=True).raw)
114
- x.download('./')
115
- x_im = Image.open(x.path)
116
-
117
- # vis
118
- axs[idx, j].imshow(x_im)
119
- axs[idx, j].set_title(f'{y}{j}')
120
- axs[idx, j].axis('off')
121
-
122
- x_im = preprocess(x_im) # (3, H, W)
123
- supp_x.append(x_im)
124
- supp_y.append(idx)
125
 
126
  print('Searching for support images is done.')
127
 
@@ -148,7 +146,6 @@ gr.Interface(fn=inference,
148
  inputs=[
149
  gr.inputs.Image(label="Image to classify", type="pil"),
150
  gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
151
- #gr.inputs.Number(default=1, label="Number of support examples from Google")
152
  gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
153
  ],
154
  theme="grass",
 
4
  import random
5
  import torch
6
  import torchvision.transforms as transforms
 
7
  import gradio as gr
8
  import matplotlib.pyplot as plt
9
 
 
11
  from dotmap import DotMap
12
  from PIL import Image
13
 
14
+ #os.environ['TERM'] = 'linux'
15
+ #os.environ['TERMINFO'] = '/etc/terminfo'
 
 
16
 
17
  # args
18
  args = DotMap()
19
  args.deploy = 'vanilla'
20
  args.arch = 'dino_small_patch16'
 
21
  args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
22
  args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
23
  args.cx = '06d75168141bc47f1'
 
27
  device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  model = get_model(args)
29
  model.to(device)
 
30
  checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
31
  model.load_state_dict(checkpoint['model'], strict=True)
32
 
 
58
  # Google image search
59
  from google_images_search import GoogleImagesSearch
60
 
61
+ class MyGIS(GoogleImagesSearch):
62
+ def __enter__(self):
63
+ return self
64
+ def __exit__(self, exc_type, exc_val, exc_tb):
65
+ return
66
+
67
  # define search params
68
  # option for commonly used search param are shown below for easy reference.
69
  # For param marked with '##':
 
91
  labels = labels.split(',')
92
  n_supp = int(n_supp)
93
 
 
94
  fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
95
 
96
  with torch.no_grad():
 
102
 
103
  # search support images
104
  for idx, y in enumerate(labels):
105
+ gis = GoogleImagesSearch(args.api_key, args.cx)
106
+ _search_params['q'] = y
107
+ _search_params['num'] = n_supp
108
+ gis.search(search_params=_search_params, custom_image_name='my_image')
109
+ gis._custom_image_name = 'my_image'
110
+
111
+ for j, x in enumerate(gis.results()):
112
+ x.download('./')
113
+ x_im = Image.open(x.path)
114
+
115
+ # vis
116
+ axs[idx, j].imshow(x_im)
117
+ axs[idx, j].set_title(f'{y}{j}')
118
+ axs[idx, j].axis('off')
119
+
120
+ x_im = preprocess(x_im) # (3, H, W)
121
+ supp_x.append(x_im)
122
+ supp_y.append(idx)
 
 
123
 
124
  print('Searching for support images is done.')
125
 
 
146
  inputs=[
147
  gr.inputs.Image(label="Image to classify", type="pil"),
148
  gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
 
149
  gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
150
  ],
151
  theme="grass",