hushell commited on
Commit
16564f3
1 Parent(s): ec3f141

add more options for GIS

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -4,7 +4,7 @@ import time
4
  import random
5
  import torch
6
  import torchvision.transforms as transforms
7
- import gradio as gr
8
  import matplotlib.pyplot as plt
9
 
10
  from models import get_model
@@ -83,7 +83,9 @@ _search_params = {
83
 
84
 
85
  # Gradio UI
86
- def inference(query, labels, n_supp=10):
 
 
87
  '''
88
  query: PIL image
89
  labels: list of class names
@@ -91,6 +93,12 @@ def inference(query, labels, n_supp=10):
91
  labels = labels.split(',')
92
  n_supp = int(n_supp)
93
 
 
 
 
 
 
 
94
  fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
95
 
96
  with torch.no_grad():
@@ -104,9 +112,8 @@ def inference(query, labels, n_supp=10):
104
  for idx, y in enumerate(labels):
105
  gis = GoogleImagesSearch(args.api_key, args.cx)
106
  _search_params['q'] = y
107
- _search_params['num'] = n_supp
108
  gis.search(search_params=_search_params, custom_image_name='my_image')
109
- gis._custom_image_name = 'my_image'
110
 
111
  for j, x in enumerate(gis.results()):
112
  x.download('./')
@@ -135,9 +142,10 @@ def inference(query, labels, n_supp=10):
135
 
136
 
137
  # DEBUG
138
- #query = Image.open('../labrador-puppy.jpg')
 
139
  ##labels = 'dog, cat'
140
- #labels = 'girl, boy'
141
  #output = inference(query, labels, n_supp=2)
142
  #print(output)
143
 
@@ -146,7 +154,11 @@ gr.Interface(fn=inference,
146
  inputs=[
147
  gr.inputs.Image(label="Image to classify", type="pil"),
148
  gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
149
- gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
 
 
 
 
150
  ],
151
  theme="grass",
152
  outputs=[
 
4
  import random
5
  import torch
6
  import torchvision.transforms as transforms
7
+ #import gradio as gr
8
  import matplotlib.pyplot as plt
9
 
10
  from models import get_model
 
83
 
84
 
85
  # Gradio UI
86
+ def inference(query, labels, n_supp=10,
87
+ file_type='png', rights='cc_publicdomain',
88
+ image_type='photo', color_type='color'):
89
  '''
90
  query: PIL image
91
  labels: list of class names
 
93
  labels = labels.split(',')
94
  n_supp = int(n_supp)
95
 
96
+ _search_params['num'] = n_supp
97
+ _search_params['fileType'] = file_type
98
+ _search_params['rights'] = rights
99
+ _search_params['imgType'] = image_type
100
+ _search_params['imgColorType'] = color_type
101
+
102
  fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
103
 
104
  with torch.no_grad():
 
112
  for idx, y in enumerate(labels):
113
  gis = GoogleImagesSearch(args.api_key, args.cx)
114
  _search_params['q'] = y
 
115
  gis.search(search_params=_search_params, custom_image_name='my_image')
116
+ gis._custom_image_name = 'my_image' # fix: image name sometimes too long
117
 
118
  for j, x in enumerate(gis.results()):
119
  x.download('./')
 
142
 
143
 
144
  # DEBUG
145
+ ##query = Image.open('../labrador-puppy.jpg')
146
+ #query = Image.open('/Users/hushell/Documents/Dan_tr.png')
147
  ##labels = 'dog, cat'
148
+ #labels = 'girl, sussie'
149
  #output = inference(query, labels, n_supp=2)
150
  #print(output)
151
 
 
154
  inputs=[
155
  gr.inputs.Image(label="Image to classify", type="pil"),
156
  gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
157
+ gr.inputs.Slider(minimum=2, maximum=10, step=1, label="GIS: Number of support examples per class"),
158
+ gr.inputs.Dropdown(['png', 'jpg'], default='png', label='GIS: Image file type'),
159
+ gr.inputs.Dropdown(['cc_publicdomain', 'cc_attribute', 'cc_sharealike', 'cc_noncommercial', 'cc_nonderived'], default='cc_publicdomain', label='GIS: Copy rights'),
160
+ gr.inputs.Dropdown(['clipart', 'face', 'lineart', 'stock', 'photo', 'animated', 'imgTypeUndefined'], default='photo', label='GIS: Image type'),
161
+ gr.inputs.Dropdown(['color', 'gray', 'mono', 'trans', 'imgColorTypeUndefined'], default='color', label='GIS: Image color type'),
162
  ],
163
  theme="grass",
164
  outputs=[