taesiri commited on
Commit
0a8da33
1 Parent(s): cbda7ef

remove knn

Browse files
Files changed (1) hide show
  1. app.py +11 -16
app.py CHANGED
@@ -191,18 +191,14 @@ def load_sample(data, current_index):
191
  image_id = data[current_index]["id"]
192
  qimage = data[current_index]["image"]
193
 
194
- neighbors_path = os.path.join(knn_cache_path, f"{image_id}.JPEG")
195
- neighbors_image = Image.open(neighbors_path).convert('RGB')
196
-
197
  labels = data[current_index]["correct_label"]
198
- return qimage, neighbors_image, labels
199
  # return qimage, neighbors_image, training_samples_image
200
 
201
 
202
  def update_app(decision, data, current_index, history, username):
203
  if current_index == -1:
204
  data = generate_dataset()
205
- nns = {}
206
 
207
  if current_index>=0 and current_index < NUMBER_OF_IMAGES-1:
208
  time_stamp = int(time.time())
@@ -233,20 +229,19 @@ def update_app(decision, data, current_index, history, username):
233
  os.remove(temp_filename)
234
 
235
  elif current_index == NUMBER_OF_IMAGES-1:
236
- return None, None, None, current_index, history, data, None, None
237
 
238
  current_index += 1
239
- qimage, neighbors_image, labels = load_sample(data, current_index)
240
  image_id = data[current_index]["id"]
241
  training_samples_image = get_training_samples(image_id)
242
  training_samples_image = [Image.open(x).convert('RGB') for x in training_samples_image]
243
- nns = label_dist_of_nns(image_id)
244
 
245
  # labels is a list of labels, conver it to a string
246
  labels = ", ".join(labels)
247
  label_plot = string_to_image(labels)
248
 
249
- return qimage, label_plot, neighbors_image, current_index, history, data, nns, training_samples_image
250
 
251
 
252
  newcss = '''
@@ -286,26 +281,26 @@ with gr.Blocks(css=newcss) as demo:
286
  with gr.Column():
287
  label_plot = gr.Plot(label='Is this a correct label for this image?', type='fig')
288
  training_samples = gr.Gallery(type="pil", label="Training samples" , elem_id="sample_gallery")
289
- with gr.Column():
290
- gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)")
291
- nn_labels = gr.Label(label="NN-Labels")
292
- neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery")
293
 
294
  accept_btn.click(
295
  update_app,
296
  inputs=[accept_btn, data_gr, current_index, history, username],
297
- outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples]
298
  )
299
  myabe_btn.click(
300
  update_app,
301
  inputs=[myabe_btn, data_gr, current_index, history, username],
302
- outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples]
303
  )
304
 
305
  reject_btn.click(
306
  update_app,
307
  inputs=[reject_btn, data_gr, current_index, history, username],
308
- outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples]
309
  )
310
 
311
  demo.launch()
 
191
  image_id = data[current_index]["id"]
192
  qimage = data[current_index]["image"]
193
 
 
 
 
194
  labels = data[current_index]["correct_label"]
195
+ return qimage, labels
196
  # return qimage, neighbors_image, training_samples_image
197
 
198
 
199
  def update_app(decision, data, current_index, history, username):
200
  if current_index == -1:
201
  data = generate_dataset()
 
202
 
203
  if current_index>=0 and current_index < NUMBER_OF_IMAGES-1:
204
  time_stamp = int(time.time())
 
229
  os.remove(temp_filename)
230
 
231
  elif current_index == NUMBER_OF_IMAGES-1:
232
+ return None, None, current_index, history, data, None
233
 
234
  current_index += 1
235
+ qimage, labels = load_sample(data, current_index)
236
  image_id = data[current_index]["id"]
237
  training_samples_image = get_training_samples(image_id)
238
  training_samples_image = [Image.open(x).convert('RGB') for x in training_samples_image]
 
239
 
240
  # labels is a list of labels, conver it to a string
241
  labels = ", ".join(labels)
242
  label_plot = string_to_image(labels)
243
 
244
+ return qimage, label_plot, current_index, history, data, training_samples_image
245
 
246
 
247
  newcss = '''
 
281
  with gr.Column():
282
  label_plot = gr.Plot(label='Is this a correct label for this image?', type='fig')
283
  training_samples = gr.Gallery(type="pil", label="Training samples" , elem_id="sample_gallery")
284
+ # with gr.Column():
285
+ # gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)")
286
+ # nn_labels = gr.Label(label="NN-Labels")
287
+ # neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery")
288
 
289
  accept_btn.click(
290
  update_app,
291
  inputs=[accept_btn, data_gr, current_index, history, username],
292
+ outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
293
  )
294
  myabe_btn.click(
295
  update_app,
296
  inputs=[myabe_btn, data_gr, current_index, history, username],
297
+ outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
298
  )
299
 
300
  reject_btn.click(
301
  update_app,
302
  inputs=[reject_btn, data_gr, current_index, history, username],
303
+ outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
304
  )
305
 
306
  demo.launch()