import csv
import json
import os
import pickle
import random
import string
import sys
import time
from glob import glob

import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torchvision
from huggingface_hub import HfApi, snapshot_download
from PIL import Image

csv.field_size_limit(sys.maxsize)
np.random.seed(int(time.time()))

# Nearest-neighbor indices (into the training set) for each ImageNet-Hard image.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)

with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)

with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

# Ids of the images flagged for review, one "<id>.<ext>" entry per line.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]

# Download the KNN cache and training-sample images.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

# Extract if needed. The archive ships the folder as "imagenet_traning_samples"
# (sic), so the on-disk paths keep that spelling.
if not os.path.exists("./imagenet_traning_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")


def update_snapshot():
    """Download the latest review decisions and return them as a DataFrame."""
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns="*.json",
        repo_type="dataset",
    )
    total_size = len(imagenet_hard)
    files = glob(f"{output_dir}/*.json")

    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in files:
        with open(file) as f:
            data = json.load(f)
        rows.append([data[x] for x in columns])

    df = pd.DataFrame(rows, columns=columns)
    return df, total_size


NUMBER_OF_IMAGES = 1000


def sample_ids(df, total_size, sample_size):
    """Sample `sample_size` ids, weighted toward the least-reviewed images."""
    id_counts = df["id"].value_counts().to_dict()

    all_ids = bad_items
    for id in all_ids:
        if id not in id_counts:
            id_counts[id] = 0

    # Weight each id by the inverse of (review count + 1), then normalize.
    counts = [id_counts[id] for id in all_ids]
    inverse_weights = [1 / (count + 1) for count in counts]
    total = sum(inverse_weights)
    normalized_weights = [w / total for w in inverse_weights]

    sampled_ids = np.random.choice(
        all_ids, size=sample_size, replace=False, p=normalized_weights
    )
    return sampled_ids


def generate_dataset():
    df, total_size = update_snapshot()
    random_indices = sample_ids(df, total_size, NUMBER_OF_IMAGES)

    random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
    random_gt_labels = [imagenet_hard[int(i)]["english_label"] for i in random_indices]

    data = []
    for i, image in enumerate(random_images):
        data.append(
            {
                "id": random_indices[i],
                "image": image,
                "correct_label": random_gt_labels[i],
                "original_id": int(random_indices[i]),
            }
        )
    return data


def string_to_image(text):
    """Render a label string onto a blank white image and return the figure."""
    text = text.replace("_", " ").lower().replace(", ", "\n")

    # Create a blank white image.
    img = np.ones((220, 75, 3))

    fig, ax = plt.subplots(figsize=(6, 2.25))
    ax.imshow(img, extent=[0, 1, 0, 1])

    # Draw the label text.
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")

    # Remove the axis labels, ticks, and spines.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    return fig
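# A minimal usage sketch of the sampling helpers above, kept as comments so it
# never runs inside the app; the `sample_size` value is illustrative only:
#
#   df, total_size = update_snapshot()
#   ids = sample_ids(df, total_size, sample_size=5)
#   # `ids` now holds five image ids, biased toward the least-reviewed ones.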
def label_dist_of_nns(qid):
    """Label distribution (as fractions) over the query's 15 nearest neighbors."""
    with open("./trainingset_filenames.json", "r") as f:
        trainingset_filenames = json.load(f)

    nns = knn_results[qid][:15]
    labels = [wnid_to_label[trainingset_filenames[f"{x}"]] for x in nns]
    label_counts = {x: labels.count(x) for x in set(labels)}

    # Sort by count (descending), then convert counts to fractions.
    label_counts = {
        k: v
        for k, v in sorted(
            label_counts.items(), key=lambda item: item[1], reverse=True
        )
    }
    label_counts = {k: v / len(labels) for k, v in label_counts.items()}

    return label_counts


all_samples = glob("./imagenet_traning_samples/*.JPEG")
qid_to_sample = {
    int(x.split("/")[-1].split(".")[0].split("_")[0]): x for x in all_samples
}


def get_training_samples(qid):
    labels_id = imagenet_hard[int(qid)]["label"]
    samples = [qid_to_sample[x] for x in labels_id]
    return samples


knn_cache_path = "knn_cache_for_imagenet_hard"
imagenet_training_samples_path = "imagenet_traning_samples"


def load_sample(data, current_index):
    image_id = data[current_index]["id"]
    qimage = data[current_index]["image"]

    neighbors_path = os.path.join(knn_cache_path, f"{image_id}.JPEG")
    neighbors_image = Image.open(neighbors_path).convert("RGB")

    labels = data[current_index]["correct_label"]
    return qimage, neighbors_image, labels


def update_app(decision, data, current_index, history, username):
    if current_index == -1:
        data = generate_dataset()

    nns = {}
    if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
        time_stamp = int(time.time())
        image_id = data[current_index]["id"]

        decision_dict = {
            "id": int(image_id),
            "user_id": username,
            "time": time_stamp,
            "decision": decision,
        }

        # Save the decision to disk, upload it to the results dataset,
        # then clean up the temporary file.
        temp_filename = f"results_{username}_{time_stamp}.json"
        with open(temp_filename, "w") as f:
            json.dump(decision_dict, f)
        api = HfApi()
        api.upload_file(
            path_or_fileobj=temp_filename,
            path_in_repo=temp_filename,
            repo_id="taesiri/imagenet_hard_review_data",
            repo_type="dataset",
        )
        os.remove(temp_filename)
    elif current_index == NUMBER_OF_IMAGES - 1:
        return None, None, None, current_index, history, data, None, None

    current_index += 1
    qimage, neighbors_image, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]

    training_samples_image = get_training_samples(image_id)
    training_samples_image = [
        Image.open(x).convert("RGB") for x in training_samples_image
    ]

    nns = label_dist_of_nns(image_id)

    # `labels` is a list of labels; join it into a single display string.
    labels = ", ".join(labels)
    label_plot = string_to_image(labels)

    return (
        qimage,
        label_plot,
        neighbors_image,
        current_index,
        history,
        data,
        nns,
        training_samples_image,
    )


newcss = """
#query_image{
    height: auto !important;
}

#nn_gallery {
    height: auto !important;
}

#sample_gallery {
    height: auto !important;
}
"""

with gr.Blocks(css=newcss) as demo:
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Cleaning ImageNet-Hard!")

    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )

    username = gr.Textbox(label="Username", value=f"user-{random_str}")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            maybe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
label="Query", elem_id="query_image") with gr.Column(): label_plot = gr.Plot(label='Is this a correct label for this image?', type='fig') training_samples = gr.Gallery(type="pil", label="Training samples" , elem_id="sample_gallery") with gr.Column(): gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)") nn_labels = gr.Label(label="NN-Labels") neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery") accept_btn.click( update_app, inputs=[accept_btn, data_gr, current_index, history, username], outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples] ) myabe_btn.click( update_app, inputs=[myabe_btn, data_gr, current_index, history, username], outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples] ) reject_btn.click( update_app, inputs=[reject_btn, data_gr, current_index, history, username], outputs=[query_image, label_plot, neighbors_image, current_index, history, data_gr, nn_labels, training_samples] ) demo.launch()