taesiri commited on
Commit
15d5b2e
1 Parent(s): 0a8da33
Files changed (1) hide show
  1. app.py +173 -104
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import csv
3
  import json
4
  import os
@@ -20,28 +19,30 @@ import torchvision
20
  from huggingface_hub import HfApi, login, snapshot_download
21
  from PIL import Image
22
 
23
-
24
- session_token = os.environ.get("SessionToken")
25
- login(token=session_token)
26
 
27
 
28
  csv.field_size_limit(sys.maxsize)
29
 
30
  np.random.seed(int(time.time()))
31
 
32
- with open('./imagenet_hard_nearest_indices.pkl', 'rb') as f:
33
- knn_results = pickle.load(f)
34
 
35
  with open("imagenet-labels.json") as f:
36
  wnid_to_label = json.load(f)
37
 
38
- with open('id_to_label.json', 'r') as f:
39
  id_to_labels = json.load(f)
40
 
 
41
 
42
- bad_items = open('./ex2.txt', 'r').read().split('\n')
43
- bad_items = [x.split('.')[0] for x in bad_items]
44
- bad_items = [int(x) for x in bad_items if x != '']
 
 
45
 
46
  # download and extract folders
47
 
@@ -54,7 +55,9 @@ gdown.cached_download(
54
 
55
  # EXTRACT if needed
56
 
57
- if not os.path.exists("./imagenet_traning_samples") or not os.path.exists("./knn_cache_for_imagenet_hard"):
 
 
58
  torchvision.datasets.utils.extract_archive(
59
  from_path="data.zip",
60
  to_path="./",
@@ -64,11 +67,12 @@ if not os.path.exists("./imagenet_traning_samples") or not os.path.exists("./knn
64
  imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
65
 
66
 
67
- def update_snapshot():
68
  output_dir = snapshot_download(
69
- repo_id="taesiri/imagenet_hard_review_data", allow_patterns="*.json", repo_type="dataset"
 
 
70
  )
71
- total_size = len(imagenet_hard)
72
  files = glob(f"{output_dir}/*.json")
73
 
74
  df = pd.DataFrame()
@@ -83,42 +87,35 @@ def update_snapshot():
83
  rows.append(tdf)
84
 
85
  df = pd.DataFrame(rows, columns=columns)
 
86
 
87
- return df, total_size
88
-
89
 
90
- # df = update_snapshot()
91
 
92
- NUMBER_OF_IMAGES = 1000
 
 
93
 
94
- # Function to sample 10 ids based on their usage count
95
- def sample_ids(df, total_size, sample_size):
96
- id_counts = df['id'].value_counts().to_dict()
97
- all_ids = bad_items
98
 
99
- for id in all_ids:
100
- if id not in id_counts:
101
- id_counts[id] = 0
 
 
 
 
102
 
103
- weights = [id_counts[id] for id in all_ids]
104
- inverse_weights = [1 / (count + 1) for count in weights]
105
- normalized_weights = [w / sum(inverse_weights) for w in inverse_weights]
106
-
107
- sampled_ids = np.random.choice(all_ids, size=sample_size, replace=False, p=normalized_weights)
108
- return sampled_ids
109
-
110
-
111
- def generate_dataset():
112
- df, total_size = update_snapshot()
113
- random_indices = sample_ids(df, total_size, NUMBER_OF_IMAGES)
114
  random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
115
  random_gt_ids = [imagenet_hard[int(i)]["label"] for i in random_indices]
116
- random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]
117
 
118
  data = []
119
  for i, image in enumerate(random_images):
120
  data.append(
121
- {
122
  "id": random_indices[i],
123
  "image": image,
124
  "correct_label": random_gt_labels[i],
@@ -128,83 +125,86 @@ def generate_dataset():
128
  return data
129
 
130
 
131
-
132
  def string_to_image(text):
133
- text = text.replace('_', ' ').lower().replace(', ', '\n')
134
  # Create a blank white square image
135
  img = np.ones((220, 75, 3))
136
 
137
- # Create a figure and axis object
138
  fig, ax = plt.subplots(figsize=(6, 2.25))
139
-
140
- # Plot the blank white image
141
  ax.imshow(img, extent=[0, 1, 0, 1])
142
-
143
- # Set the text in the center
144
- ax.text(0.5, 0.75, text, fontsize=18, ha='center', va='center')
145
-
146
- # Remove the axis labels and ticks
147
  ax.set_xticks([])
148
  ax.set_yticks([])
149
  ax.set_xticklabels([])
150
  ax.set_yticklabels([])
151
-
152
- # Remove the axis spines
153
  for spine in ax.spines.values():
154
  spine.set_visible(False)
155
 
156
- # Return the figure
157
  return fig
158
 
159
 
 
 
 
 
160
 
161
- def label_dist_of_nns(qid):
162
-
163
- with open('./trainingset_filenames.json', 'r') as f:
164
- trainingset_filenames = json.load(f)
165
-
166
- nns = knn_results[qid][:15]
167
- labels = [wnid_to_label[trainingset_filenames[f"{x}"]] for x in nns]
168
- label_counts = {x: labels.count(x) for x in set(labels)}
169
- # sort by count
170
- label_counts = {k: v for k, v in sorted(label_counts.items(), key=lambda item: item[1], reverse=True)}
171
- # percetage
172
- label_counts = {k: v/len(labels) for k, v in label_counts.items()}
173
- return label_counts
174
-
175
-
176
- from glob import glob
177
 
178
- all_samples = glob('./imagenet_traning_samples/*.JPEG')
179
- qid_to_sample = {int(x.split('/')[-1].split('.')[0].split('_')[0]): x for x in all_samples}
180
 
181
  def get_training_samples(qid):
182
- labels_id = imagenet_hard[int(qid)]['label']
183
  samples = [qid_to_sample[x] for x in labels_id]
184
  return samples
185
 
186
 
187
- knn_cache_path = "knn_cache_for_imagenet_hard"
188
- imagenet_training_samples_path = "imagenet_traning_samples"
189
-
190
  def load_sample(data, current_index):
191
  image_id = data[current_index]["id"]
192
  qimage = data[current_index]["image"]
193
-
194
  labels = data[current_index]["correct_label"]
195
  return qimage, labels
196
- # return qimage, neighbors_image, training_samples_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
 
199
  def update_app(decision, data, current_index, history, username):
 
200
  if current_index == -1:
201
- data = generate_dataset()
202
-
203
- if current_index>=0 and current_index < NUMBER_OF_IMAGES-1:
204
  time_stamp = int(time.time())
205
 
206
  image_id = data[current_index]["id"]
207
- # convert to percentage
208
  dicision_dict = {
209
  "id": int(image_id),
210
  "user_id": username,
@@ -228,23 +228,56 @@ def update_app(decision, data, current_index, history, username):
228
 
229
  os.remove(temp_filename)
230
 
231
- elif current_index == NUMBER_OF_IMAGES-1:
232
- return None, None, current_index, history, data, None
 
233
 
234
- current_index += 1
235
- qimage, labels = load_sample(data, current_index)
236
- image_id = data[current_index]["id"]
237
- training_samples_image = get_training_samples(image_id)
238
- training_samples_image = [Image.open(x).convert('RGB') for x in training_samples_image]
239
 
240
- # labels is a list of labels, conver it to a string
241
- labels = ", ".join(labels)
242
- label_plot = string_to_image(labels)
 
 
 
 
 
243
 
244
- return qimage, label_plot, current_index, history, data, training_samples_image
 
 
 
 
245
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- newcss = '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  #query_image{
249
  height: auto !important;
250
  }
@@ -256,51 +289,87 @@ newcss = '''
256
  #sample_gallery {
257
  height: auto !important;
258
  }
259
- '''
260
 
261
  with gr.Blocks(css=newcss) as demo:
262
  data_gr = gr.State({})
263
  current_index = gr.State(-1)
264
  history = gr.State({})
265
-
266
  gr.Markdown("# Cleaning ImageNet-Hard!")
267
 
268
  random_str = "".join(
269
  random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
270
  )
271
 
272
- username = gr.Textbox(label="Username", value=f"user-{random_str}")
 
 
273
 
274
  with gr.Column():
275
  with gr.Row():
276
  accept_btn = gr.Button(value="Accept")
277
  myabe_btn = gr.Button(value="Not Sure!")
278
  reject_btn = gr.Button(value="Reject")
279
- with gr.Row():
280
  query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
281
  with gr.Column():
282
- label_plot = gr.Plot(label='Is this a correct label for this image?', type='fig')
283
- training_samples = gr.Gallery(type="pil", label="Training samples" , elem_id="sample_gallery")
284
- # with gr.Column():
285
- # gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)")
286
- # nn_labels = gr.Label(label="NN-Labels")
287
- # neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery")
288
 
289
  accept_btn.click(
290
  update_app,
291
  inputs=[accept_btn, data_gr, current_index, history, username],
292
- outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
 
 
 
 
 
 
 
293
  )
294
  myabe_btn.click(
295
  update_app,
296
  inputs=[myabe_btn, data_gr, current_index, history, username],
297
- outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
 
 
 
 
 
 
 
298
  )
299
 
300
  reject_btn.click(
301
  update_app,
302
  inputs=[reject_btn, data_gr, current_index, history, username],
303
- outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  )
305
 
306
  demo.launch()
 
 
1
  import csv
2
  import json
3
  import os
 
19
  from huggingface_hub import HfApi, login, snapshot_download
20
  from PIL import Image
21
 
22
+ # session_token = os.environ.get("SessionToken")
23
+ # login(token=session_token)
 
24
 
25
 
26
  csv.field_size_limit(sys.maxsize)
27
 
28
  np.random.seed(int(time.time()))
29
 
30
+ with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
31
+ knn_results = pickle.load(f)
32
 
33
  with open("imagenet-labels.json") as f:
34
  wnid_to_label = json.load(f)
35
 
36
+ with open("id_to_label.json", "r") as f:
37
  id_to_labels = json.load(f)
38
 
39
+ imagenet_training_samples_path = "imagenet_traning_samples"
40
 
41
+ bad_items = open("./ex2.txt", "r").read().split("\n")
42
+ bad_items = [x.split(".")[0] for x in bad_items]
43
+ bad_items = [int(x) for x in bad_items if x != ""]
44
+
45
+ NUMBER_OF_IMAGES = 100 # len(bad_items)
46
 
47
  # download and extract folders
48
 
 
55
 
56
  # EXTRACT if needed
57
 
58
+ if not os.path.exists("./imagenet_traning_samples") or not os.path.exists(
59
+ "./knn_cache_for_imagenet_hard"
60
+ ):
61
  torchvision.datasets.utils.extract_archive(
62
  from_path="data.zip",
63
  to_path="./",
 
67
  imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
68
 
69
 
70
+ def update_snapshot(username):
71
  output_dir = snapshot_download(
72
+ repo_id="taesiri/imagenet_hard_review_data",
73
+ allow_patterns="*.json",
74
+ repo_type="dataset",
75
  )
 
76
  files = glob(f"{output_dir}/*.json")
77
 
78
  df = pd.DataFrame()
 
87
  rows.append(tdf)
88
 
89
  df = pd.DataFrame(rows, columns=columns)
90
+ df = df[df["user_id"] == username]
91
 
92
+ return df
 
93
 
 
94
 
95
+ def generate_dataset(username):
96
+ global NUMBER_OF_IMAGES
97
+ df = update_snapshot(username)
98
 
99
+ all_images = set(bad_items)
100
+ answered = set(df.id)
101
+ remaining = list(all_images - answered)
 
102
 
103
+ if len(remaining) < NUMBER_OF_IMAGES and len(remaining) > 0:
104
+ NUMBER_OF_IMAGES = len(remaining)
105
+ random_indices = list(remaining)
106
+ elif len(remaining) == 0:
107
+ return []
108
+ else:
109
+ random_indices = np.random.choice(remaining, NUMBER_OF_IMAGES, replace=False)
110
 
 
 
 
 
 
 
 
 
 
 
 
111
  random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
112
  random_gt_ids = [imagenet_hard[int(i)]["label"] for i in random_indices]
113
+ random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]
114
 
115
  data = []
116
  for i, image in enumerate(random_images):
117
  data.append(
118
+ {
119
  "id": random_indices[i],
120
  "image": image,
121
  "correct_label": random_gt_labels[i],
 
125
  return data
126
 
127
 
 
128
  def string_to_image(text):
129
+ text = text.replace("_", " ").lower().replace(", ", "\n")
130
  # Create a blank white square image
131
  img = np.ones((220, 75, 3))
132
 
 
133
  fig, ax = plt.subplots(figsize=(6, 2.25))
 
 
134
  ax.imshow(img, extent=[0, 1, 0, 1])
135
+ ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
 
 
 
 
136
  ax.set_xticks([])
137
  ax.set_yticks([])
138
  ax.set_xticklabels([])
139
  ax.set_yticklabels([])
 
 
140
  for spine in ax.spines.values():
141
  spine.set_visible(False)
142
 
 
143
  return fig
144
 
145
 
146
+ all_samples = glob("./imagenet_traning_samples/*.JPEG")
147
+ qid_to_sample = {
148
+ int(x.split("/")[-1].split(".")[0].split("_")[0]): x for x in all_samples
149
+ }
150
 
151
+ # user-e3z5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
 
 
153
 
154
  def get_training_samples(qid):
155
+ labels_id = imagenet_hard[int(qid)]["label"]
156
  samples = [qid_to_sample[x] for x in labels_id]
157
  return samples
158
 
159
 
 
 
 
160
  def load_sample(data, current_index):
161
  image_id = data[current_index]["id"]
162
  qimage = data[current_index]["image"]
163
+
164
  labels = data[current_index]["correct_label"]
165
  return qimage, labels
166
+
167
+
168
+ def preprocessing(data, current_index, history, username):
169
+ data = generate_dataset(username)
170
+
171
+ if len(data) == 0:
172
+ fake_plot = string_to_image("No more images to review")
173
+ empty_image = Image.new("RGB", (224, 224))
174
+ return (
175
+ empty_image,
176
+ fake_plot,
177
+ current_index,
178
+ history,
179
+ data,
180
+ None,
181
+ )
182
+
183
+ current_index = 0
184
+ qimage, labels = load_sample(data, current_index)
185
+ image_id = data[current_index]["id"]
186
+ training_samples_image = get_training_samples(image_id)
187
+ training_samples_image = [
188
+ Image.open(x).convert("RGB") for x in training_samples_image
189
+ ]
190
+
191
+ # labels is a list of labels, conver it to a string
192
+ labels = ", ".join(labels)
193
+ label_plot = string_to_image(labels)
194
+
195
+ return qimage, label_plot, current_index, history, data, training_samples_image
196
 
197
 
198
  def update_app(decision, data, current_index, history, username):
199
+ global NUMBER_OF_IMAGES
200
  if current_index == -1:
201
+ return
202
+
203
+ if current_index == NUMBER_OF_IMAGES - 1:
204
  time_stamp = int(time.time())
205
 
206
  image_id = data[current_index]["id"]
207
+ # convert to percentage
208
  dicision_dict = {
209
  "id": int(image_id),
210
  "user_id": username,
 
228
 
229
  os.remove(temp_filename)
230
 
231
+ fake_plot = string_to_image("Thank you for your time!")
232
+ empty_image = Image.new("RGB", (224, 224))
233
+ return empty_image, fake_plot, current_index, history, data, None
234
 
235
+ if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
236
+ time_stamp = int(time.time())
 
 
 
237
 
238
+ image_id = data[current_index]["id"]
239
+ # convert to percentage
240
+ dicision_dict = {
241
+ "id": int(image_id),
242
+ "user_id": username,
243
+ "time": time_stamp,
244
+ "decision": decision,
245
+ }
246
 
247
+ # upload the decision to the server
248
+ temp_filename = f"results_{username}_{time_stamp}.json"
249
+ # convert decision_dict to json and save it on the disk
250
+ with open(temp_filename, "w") as f:
251
+ json.dump(dicision_dict, f)
252
 
253
+ api = HfApi()
254
+ api.upload_file(
255
+ path_or_fileobj=temp_filename,
256
+ path_in_repo=temp_filename,
257
+ repo_id="taesiri/imagenet_hard_review_data",
258
+ repo_type="dataset",
259
+ )
260
+
261
+ os.remove(temp_filename)
262
+
263
+ # Load the Next Image
264
 
265
+ current_index += 1
266
+ qimage, labels = load_sample(data, current_index)
267
+ image_id = data[current_index]["id"]
268
+ training_samples_image = get_training_samples(image_id)
269
+ training_samples_image = [
270
+ Image.open(x).convert("RGB") for x in training_samples_image
271
+ ]
272
+
273
+ # labels is a list of labels, conver it to a string
274
+ labels = ", ".join(labels)
275
+ label_plot = string_to_image(labels)
276
+
277
+ return qimage, label_plot, current_index, history, data, training_samples_image
278
+
279
+
280
+ newcss = """
281
  #query_image{
282
  height: auto !important;
283
  }
 
289
  #sample_gallery {
290
  height: auto !important;
291
  }
292
+ """
293
 
294
  with gr.Blocks(css=newcss) as demo:
295
  data_gr = gr.State({})
296
  current_index = gr.State(-1)
297
  history = gr.State({})
298
+
299
  gr.Markdown("# Cleaning ImageNet-Hard!")
300
 
301
  random_str = "".join(
302
  random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
303
  )
304
 
305
+ with gr.Row():
306
+ username = gr.Textbox(label="Username", value=f"user-{random_str}")
307
+ prepare_btn = gr.Button(value="Load Samples")
308
 
309
  with gr.Column():
310
  with gr.Row():
311
  accept_btn = gr.Button(value="Accept")
312
  myabe_btn = gr.Button(value="Not Sure!")
313
  reject_btn = gr.Button(value="Reject")
314
+ with gr.Row():
315
  query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
316
  with gr.Column():
317
+ label_plot = gr.Plot(
318
+ label="Is this a correct label for this image?", type="fig"
319
+ )
320
+ training_samples = gr.Gallery(
321
+ type="pil", label="Training samples", elem_id="sample_gallery"
322
+ )
323
 
324
  accept_btn.click(
325
  update_app,
326
  inputs=[accept_btn, data_gr, current_index, history, username],
327
+ outputs=[
328
+ query_image,
329
+ label_plot,
330
+ current_index,
331
+ history,
332
+ data_gr,
333
+ training_samples,
334
+ ],
335
  )
336
  myabe_btn.click(
337
  update_app,
338
  inputs=[myabe_btn, data_gr, current_index, history, username],
339
+ outputs=[
340
+ query_image,
341
+ label_plot,
342
+ current_index,
343
+ history,
344
+ data_gr,
345
+ training_samples,
346
+ ],
347
  )
348
 
349
  reject_btn.click(
350
  update_app,
351
  inputs=[reject_btn, data_gr, current_index, history, username],
352
+ outputs=[
353
+ query_image,
354
+ label_plot,
355
+ current_index,
356
+ history,
357
+ data_gr,
358
+ training_samples,
359
+ ],
360
+ )
361
+
362
+ prepare_btn.click(
363
+ preprocessing,
364
+ inputs=[data_gr, current_index, history, username],
365
+ outputs=[
366
+ query_image,
367
+ label_plot,
368
+ current_index,
369
+ history,
370
+ data_gr,
371
+ training_samples,
372
+ ],
373
  )
374
 
375
  demo.launch()