taesiri's picture
backup
e859cf6
raw
history blame
11 kB
import csv
import json
import os
import pickle
import random
import string
import sys
import time
from glob import glob
import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image
# --- One-time startup: authentication, lookup tables and data files. ---

# Log in to the Hugging Face Hub with the token from the "SessionToken"
# environment variable (None when the variable is not set).
session_token = os.environ.get("SessionToken")
login(token=session_token)
csv.field_size_limit(sys.maxsize)
# NOTE(review): this seeds NumPy's RNG only; the review order is shuffled with
# the stdlib `random` module, which this call does not affect.
np.random.seed(int(time.time()))

# Precomputed nearest-neighbour indices shipped with the app.
# NOTE(review): pickle.load can execute arbitrary code; acceptable only because
# this is a local, trusted artifact.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)
with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)
with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

# Directory holding the training-set example images (expected to be produced
# by extracting data.zip below).
imagenet_training_samples_path = "imagenet_traning_samples"

# ex2.txt lists the items to review, one file name per line; the part before
# the first "." is the integer index of the item in the ImageNet-Hard split.
# Read it inside `with` so the file handle is closed.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]
NUMBER_OF_IMAGES = len(bad_items)

# Fetch the archive with training samples and the kNN cache; cached_download is
# given the expected md5 so an already-downloaded ./data.zip can be reused.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

# Extract the archive unless both expected directories already exist.
if not os.path.exists(imagenet_training_samples_path) or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

# Validation split of ImageNet-Hard; items are addressed by integer index.
imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def update_snapshot(username):
    """Return the review decisions already uploaded by ``username``.

    Downloads every ``*.json`` file of the ``taesiri/imagenet_hard_review_data``
    dataset repo via ``snapshot_download``. Each file is read as one record,
    and the result is a DataFrame with columns ``id``, ``user_id``, ``time``
    and ``decision``, filtered to rows whose ``user_id`` equals ``username``.
    """
    snapshot_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns="*.json",
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]

    def _read_record(path):
        # One JSON file == one decision record; keep the columns in order.
        with open(path) as fh:
            record = json.load(fh)
        return [record[col] for col in columns]

    records = [_read_record(path) for path in glob(f"{snapshot_dir}/*.json")]
    decisions = pd.DataFrame(records, columns=columns)
    return decisions[decisions["user_id"] == username]
def generate_dataset(username):
    """Build the shuffled list of items ``username`` still has to review.

    Candidates are the ids in ``bad_items``; ids already present in the user's
    uploaded decisions (from ``update_snapshot``) are excluded.

    Side effect: sets the module-level ``NUMBER_OF_IMAGES`` to the number of
    items returned. NOTE(review): that global is shared by all sessions.

    Returns:
        A list of dicts with keys ``id``, ``image``, ``correct_label`` and
        ``original_id``; an empty list when nothing is left to review.
    """
    global NUMBER_OF_IMAGES
    answered = set(update_snapshot(username).id)
    remaining = list(set(bad_items) - answered)
    random.shuffle(remaining)
    # Count the items this user will actually see. The original used
    # len(bad_items), which overshoots once some items are answered. Callers
    # compare the current index against NUMBER_OF_IMAGES - 1 to detect the
    # last image, so the overshoot caused an index past the end of the list.
    NUMBER_OF_IMAGES = len(remaining)

    data = []
    for idx in remaining:
        # Fetch each row once; a row lookup returns all of its columns,
        # including the image, so the original's three lookups tripled the work.
        row = imagenet_hard[int(idx)]
        data.append(
            {
                "id": idx,
                "image": row["image"],
                "correct_label": row["english_label"],
                "original_id": int(idx),
            }
        )
    return data
def string_to_image(text):
    """Render ``text`` as a matplotlib figure for display in a ``gr.Plot``.

    Underscores become spaces, the text is lower-cased, and every ", " starts a
    new line, so a comma-separated label list is shown one label per line.

    Returns:
        A ``matplotlib.figure.Figure`` with the text centred near the top.
    """
    from matplotlib.figure import Figure

    text = text.replace("_", " ").lower().replace(", ", "\n")
    # Build the figure directly instead of through pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), and this runs on every button
    # click, so the pyplot version accumulated figures for the process lifetime.
    fig = Figure(figsize=(6, 2.25))
    ax = fig.subplots()
    # All-ones RGB array == a blank white background, stretched over the axes.
    img = np.ones((220, 75, 3))
    ax.imshow(img, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    # Hide ticks and the frame so only the text is visible.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    return fig
# Training-set example images, named "<label>_<...>.JPEG". Map the integer
# label prefix to the file path; get_training_samples looks paths up by label
# id. If several files share a prefix, the last one globbed wins.
all_samples = glob(os.path.join(imagenet_training_samples_path, "*.JPEG"))
qid_to_sample = {
    int(os.path.basename(x).split(".")[0].split("_")[0]): x for x in all_samples
}
def get_training_samples(qid):
    """Return paths of training-set example images for the labels of item ``qid``.

    Reads the label ids of ImageNet-Hard item ``qid`` (an iterable of ids) and
    maps each one to an example image path through ``qid_to_sample``.

    Labels with no example file on disk are skipped rather than raising
    ``KeyError``. One missing file should not break the whole review screen.
    """
    labels_id = imagenet_hard[int(qid)]["label"]
    return [qid_to_sample[x] for x in labels_id if x in qid_to_sample]
def load_sample(data, current_index):
    """Return ``(image, correct_label)`` for item ``current_index`` of ``data``.

    ``data`` is the list built by ``generate_dataset``; each item is a dict
    with at least ``image`` and ``correct_label`` keys.
    """
    sample = data[current_index]
    return sample["image"], sample["correct_label"]
def preprocessing(data, current_index, history, username):
    """Gradio callback for "Load Samples": load ``username``'s queue, show item 0.

    The incoming ``data`` state is ignored. A fresh shuffled list is built
    with ``generate_dataset(username)``.

    Returns:
        A 6-tuple matching the Gradio outputs: (query image, label plot,
        current index, history, data, training sample images).
    """
    data = generate_dataset(username)
    if len(data) == 0:
        fake_plot = string_to_image("No more images to review")
        empty_image = Image.new("RGB", (224, 224))
        # Reset to -1 ("nothing loaded"). Passing a stale index from an
        # earlier load through would make update_app index into the empty list.
        return empty_image, fake_plot, -1, history, data, None

    current_index = 0
    qimage, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]
    training_samples_image = [
        Image.open(x).convert("RGB") for x in get_training_samples(image_id)
    ]
    # labels is a list of label strings; display them comma-separated.
    label_plot = string_to_image(", ".join(labels))
    return qimage, label_plot, current_index, history, data, training_samples_image
def _blank_screen(message):
    """Return ``(empty 224x224 RGB image, text plot of message)`` for status screens."""
    return Image.new("RGB", (224, 224)), string_to_image(message)


def _upload_decision(image_id, username, decision):
    """Write one review decision as JSON and upload it to the review dataset repo.

    The record has keys ``id``, ``user_id``, ``time`` (Unix seconds) and
    ``decision``, the same schema ``update_snapshot`` reads back.
    """
    import re
    import tempfile

    time_stamp = int(time.time())
    decision_dict = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    # The username is free text from the UI. Keep only characters that cannot
    # act as path separators before using it in a file name. The image id is
    # included so two decisions made in the same second do not overwrite each
    # other in the repo.
    safe_username = re.sub(r"[^A-Za-z0-9_.-]", "_", str(username))
    filename = f"results_{safe_username}_{int(image_id)}_{time_stamp}.json"
    # Stage the file in a private temp dir. It is removed even if the upload
    # raises, and concurrent sessions never collide in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, filename)
        with open(local_path, "w") as f:
            json.dump(decision_dict, f)
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=filename,
            repo_id="taesiri/imagenet_hard_review_data",
            repo_type="dataset",
        )


def _review_screen(data, index):
    """Render item ``index`` of ``data`` as (query image, label plot, training images)."""
    qimage, labels = load_sample(data, index)
    training_images = [
        Image.open(path).convert("RGB")
        for path in get_training_samples(data[index]["id"])
    ]
    # labels is a list of label strings; display them comma-separated.
    return qimage, string_to_image(", ".join(labels)), training_images


def update_app(decision, data, current_index, history, username):
    """Record ``decision`` for the current image and advance to the next one.

    Gradio callback for the Accept / Not Sure! / Reject buttons.

    - ``current_index < 0``: nothing loaded yet; show a prompt, upload nothing.
    - ``current_index >= len(data)``: everything reviewed; show the thank-you
      screen without uploading again.
    - otherwise: upload the decision, then show the next image or, after the
      last one, the thank-you screen.

    Returns:
        A 6-tuple matching the Gradio outputs: (query image, label plot,
        current index, history, data, training sample images).
    """
    if current_index < 0:
        empty_image, fake_plot = _blank_screen(
            "Please Enter your username and load samples"
        )
        return empty_image, fake_plot, current_index, history, data, None

    # Use this session's own item count instead of the module-level
    # NUMBER_OF_IMAGES. That global is shared by every session, and when it
    # counted all of bad_items the last image made this function index past
    # the end of ``data``.
    n_images = len(data)
    if current_index >= n_images:
        empty_image, fake_plot = _blank_screen("Thank you for your time!")
        return empty_image, fake_plot, current_index, history, data, None

    _upload_decision(data[current_index]["id"], username, decision)

    if current_index == n_images - 1:
        # Step past the end so further clicks do not re-upload this decision.
        empty_image, fake_plot = _blank_screen("Thank you for your time!")
        return empty_image, fake_plot, current_index + 1, history, data, None

    current_index += 1
    qimage, label_plot, training_samples_image = _review_screen(data, current_index)
    return qimage, label_plot, current_index, history, data, training_samples_image
# Custom CSS passed to gr.Blocks: sets `height: auto` on the components whose
# elem_id is query_image (the query gr.Image) and sample_gallery (the
# training-samples gr.Gallery).
# NOTE(review): no component in this file uses elem_id="nn_gallery"; that rule
# appears to be unused.
newcss = """
#query_image{
height: auto !important;
}
#nn_gallery {
height: auto !important;
}
#sample_gallery {
height: auto !important;
}
"""
def _random_username():
    """Return a fresh default username of the form ``user-xxxxx``."""
    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )
    return f"user-{random_str}"


with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
    # Per-session state: the review queue, the index of the shown item (-1 =
    # nothing loaded yet), and a history dict passed through the callbacks.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})
    gr.Markdown("# Help Us to Clean `ImageNet-Hard`!")
    gr.Markdown("## Instructions")
    gr.Markdown(
        "Please enter your username and press `Load Samples`. The loading process might take up to a minute. Once the loading is done, you can start reviewing the samples."
    )
    gr.Markdown(
        """For each image, please select one of the following options: `Accept`, `Not Sure!`, `Reject`.
- If you think any of the labels are correct, please select `Accept`.
- If you think none of the labels matching the image, please select `Reject`.
- If you are not sure about the label, please select `Not Sure!`.
You can refer to `Training samples` if you are not sure about the target label.
"""
    )
    with gr.Column():
        # Pass the generator itself (a callable) rather than one precomputed
        # string. The original computed the suffix once when the UI was
        # built, so every visitor got the same default username and their
        # decisions were mixed together. A callable value is re-evaluated on
        # every page load.
        username = gr.Textbox(label="Username", value=_random_username)
        prepare_btn = gr.Button(value="Load Samples")
    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            maybe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every callback returns the same 6-tuple, so all events share one
    # outputs list.
    review_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
    ]
    # The clicked button is passed as the first input, so its value becomes
    # the `decision` argument of update_app.
    for decision_btn in (accept_btn, maybe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=review_outputs,
        )
    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=review_outputs,
    )

demo.launch()