taesiri committed
Commit 7fe7653
1 Parent(s): c0e2709
Files changed (1):
  1. app.py +15 -20

app.py CHANGED
@@ -19,8 +19,8 @@ import torchvision
 from huggingface_hub import HfApi, login, snapshot_download
 from PIL import Image
 
-session_token = os.environ.get("SessionToken")
-login(token=session_token)
+# session_token = os.environ.get("SessionToken")
+# login(token=session_token)
 
 csv.field_size_limit(sys.maxsize)
 
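With this change the Space no longer logs in to the Hub at import time. If authenticated access is needed again, a minimal sketch of a guarded re-enable (assuming the secret is still exposed as the SessionToken environment variable, as in the removed lines) would be:

import os

from huggingface_hub import login

# Only log in when the secret is actually set, so the app still starts
# in environments without a token (local runs, forks, CI).
session_token = os.environ.get("SessionToken")
if session_token:
    login(token=session_token)
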
 
@@ -35,7 +35,7 @@ with open("imagenet-labels.json") as f:
 with open("id_to_label.json", "r") as f:
     id_to_labels = json.load(f)
 
-imagenet_training_samples_path = "imagenet_traning_samples"
+imagenet_training_samples_path = "imagenet_samples"
 
 bad_items = open("./ex2.txt", "r").read().split("\n")
 bad_items = [x.split(".")[0] for x in bad_items]
@@ -48,12 +48,12 @@ gdown.cached_download(
     url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
     path="./data.zip",
     quiet=False,
-    md5="8666a9b361f6eea79878be6c09701def",
+    md5="ece2720fed664e71799f316a881d4324",
 )
 
 # EXTRACT if needed
 
-if not os.path.exists("./imagenet_traning_samples") or not os.path.exists(
+if not os.path.exists("./imagenet_samples") or not os.path.exists(
     "./knn_cache_for_imagenet_hard"
 ):
     torchvision.datasets.utils.extract_archive(
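
Put together, the updated download-and-extract flow from this hunk looks roughly like the sketch below. The extract_archive arguments are cut off by the hunk, so the from_path/to_path values here are assumptions, not a copy of app.py:

import os

import gdown
import torchvision

# cached_download skips the fetch when ./data.zip already exists and its
# MD5 matches the new checksum; otherwise it downloads the archive again.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    md5="ece2720fed664e71799f316a881d4324",
    quiet=False,
)

# Extract only when the renamed sample folder or the KNN cache is missing.
if not os.path.exists("./imagenet_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(from_path="./data.zip", to_path="./")
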
@@ -106,18 +106,11 @@ def generate_dataset(username):
     if NUMBER_OF_IMAGES == 0:
         return []
 
-    # random_indices = remaining
-    # random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
-    # random_gt_ids = [imagenet_hard[int(i)]["label"] for i in random_indices]
-    # random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]
-
     data = []
     for i, image in enumerate(remaining):
         data.append(
             {
                 "id": remaining[i],
-                # "correct_label": random_gt_labels[i],
-                # "original_id": int(random_indices[i]),
             }
         )
     return data
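
With the commented-out ground-truth fields gone, generate_dataset returns one minimal record per remaining image. A tiny illustration with a made-up remaining list:

# Hypothetical ids of images the current user has not reviewed yet.
remaining = [12, 407, 991]

data = []
for i, image in enumerate(remaining):
    data.append(
        {
            "id": remaining[i],
        }
    )

print(data)  # [{'id': 12}, {'id': 407}, {'id': 991}]
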
@@ -141,7 +134,7 @@ def string_to_image(text):
     return fig
 
 
-all_samples = glob("./imagenet_traning_samples/*.JPEG")
+all_samples = glob("./imagenet_samples/*.JPEG")
 qid_to_sample = {
     int(x.split("/")[-1].split(".")[0].split("_")[0]): x for x in all_samples
 }
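
The qid_to_sample lookup assumes every file in imagenet_samples is named <query_id>_<suffix>.JPEG. The same parsing applied to a single, hypothetical filename:

path = "./imagenet_samples/1234_n01440764.JPEG"  # hypothetical example file

# Basename -> drop the extension -> keep the part before the first underscore.
qid = int(path.split("/")[-1].split(".")[0].split("_")[0])
print(qid)  # 1234
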
@@ -367,13 +360,15 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
         reject_btn = gr.Button(value="Reject")
     with gr.Row():
         query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
-        with gr.Column():
-            label_plot = gr.Plot(
-                label="Is this a correct label for this image?", type="fig"
-            )
-            training_samples = gr.Gallery(
-                type="pil", label="Training samples", elem_id="sample_gallery"
-            )
+    with gr.Row():
+        with gr.Column(scale=1):
+            label_plot = gr.Plot(
+                label="Is this a correct label for this image?", type="fig"
+            )
+        with gr.Column(scale=3):
+            training_samples = gr.Gallery(
+                type="pil", label="Training samples", elem_id="sample_gallery"
+            )
 
     accept_btn.click(
         update_app,
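
The plot and the gallery now live in their own Row, split into two Columns whose scale values give the gallery roughly three times the width of the plot. A minimal, self-contained Gradio sketch of that layout (placeholder components only, none of the app's callbacks):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Narrow column: the label plot.
            gr.Plot(label="Is this a correct label for this image?")
        with gr.Column(scale=3):
            # Wide column: the training-sample gallery.
            gr.Gallery(label="Training samples", elem_id="sample_gallery")

if __name__ == "__main__":
    demo.launch()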
 