# Author: taesiri
# Commit: c53518e ("fix")
import csv
import json
import os
import pickle
import random
import string
import sys
import time
from glob import glob
import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image
import re
from fnmatch import translate
# Authenticate with the Hugging Face Hub using the token from the "SessionToken"
# environment variable.
session_token = os.environ.get("SessionToken")
login(token=session_token)
csv.field_size_limit(sys.maxsize)
# NOTE: this seeds NumPy's RNG only; the stdlib `random` module (used for
# shuffling the review queue and for default usernames) is seeded independently.
np.random.seed(int(time.time()))

with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    # NOTE(review): unpickling can execute arbitrary code; presumably this is a
    # trusted local asset — never point this at user-supplied data.
    knn_results = pickle.load(f)
with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)
with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)
imagenet_training_samples_path = "imagenet_samples"

# Items to review: from each line of ex2.txt take the text before the first "."
# and parse it as an int (later used as a row index into `imagenet_hard`).
# The stem is stripped so whitespace-only lines (e.g. "\r" left by CRLF files)
# are skipped instead of being passed to int().
with open("./ex2.txt", "r") as f:
    bad_items = [
        int(stem) for stem in (line.split(".")[0].strip() for line in f) if stem
    ]
NUMBER_OF_IMAGES = len(bad_items)

# Fetch the review assets; gdown verifies the download against this md5.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="ece2720fed664e71799f316a881d4324",
)
# Unpack only when either extracted directory is missing.
if not os.path.exists("./imagenet_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )
imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def _escape_fnmatch(text):
    """Return *text* with fnmatch wildcard characters made literal.

    fnmatch-style globs have no backslash escape (so ``re.escape`` output does
    not work there); instead each of ``*``, ``?`` and ``[`` is wrapped in a
    one-character class, e.g. ``"a*b"`` -> ``"a[*]b"``.
    """
    return re.sub(r"([*?\[])", r"[\1]", text)


def update_snapshot(username):
    """Download the review-result files for *username* and load them into a DataFrame.

    Files are fetched from the ``taesiri/imagenet_hard_review_data`` dataset
    repo with a glob that contains the (escaped) username, then every ``*.json``
    in the returned snapshot directory is read.

    Returns:
        pandas.DataFrame with columns ``id``, ``user_id``, ``time``, ``decision``,
        restricted to rows whose ``user_id`` equals *username* exactly.
    """
    pattern = f"*{_escape_fnmatch(username)}*.json"
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns=pattern,
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for path in glob(f"{output_dir}/*.json"):
        with open(path) as f:
            record = json.load(f)
        rows.append([record[column] for column in columns])
    df = pd.DataFrame(rows, columns=columns)
    # The glob is only a substring match on the file name, and the snapshot
    # directory may also contain files fetched earlier for other names, so keep
    # exact user_id matches only.
    return df[df["user_id"] == username]
def generate_dataset(username):
    """Build the shuffled review queue of items *username* has not decided yet.

    Items already present in the user's uploaded results (see update_snapshot)
    are removed from ``bad_items``; the rest are shuffled with ``random``.
    Also refreshes the module-level ``NUMBER_OF_IMAGES`` and prints progress.

    Returns:
        list of ``{"id": item_id}`` dicts in random order (empty when
        ``bad_items`` is empty or everything has been answered).
    """
    global NUMBER_OF_IMAGES
    answered = set(update_snapshot(username).id)
    remaining = list(set(bad_items) - answered)
    random.shuffle(remaining)
    NUMBER_OF_IMAGES = len(bad_items)
    print(f"NUMBER_OF_IMAGES: {NUMBER_OF_IMAGES}")
    print(f"Remaining: {len(remaining)}")
    if NUMBER_OF_IMAGES == 0:
        return []
    return [{"id": item_id} for item_id in remaining]
def string_to_image(text):
    """Render *text* as a matplotlib figure for display in a ``gr.Plot``.

    Underscores become spaces, the text is lower-cased, and every ", "
    separator starts a new line, so a comma-joined label list is shown one
    label per line.

    Returns:
        matplotlib.figure.Figure of size 6x2.25 inches with the text
        horizontally centred near the top and all ticks and spines hidden.
    """
    text = text.replace("_", " ").lower().replace(", ", "\n")
    fig, ax = plt.subplots(figsize=(6, 2.25))
    # Plain white RGB background (all ones), stretched over the unit square by
    # `extent` so text coordinates are in [0, 1].
    ax.imshow(np.ones((220, 75, 3)), extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    # Drop the figure from pyplot's global registry; otherwise every call keeps
    # a figure alive for the lifetime of the server. The returned Figure object
    # can still be rendered by the caller.
    plt.close(fig)
    return fig
def _qid_from_sample_path(path):
    """Parse the integer id at the start of a sample file name.

    The file name's text before the first "." and then before the first "_"
    is converted with int(), e.g. ``".../123_abc.JPEG"`` -> ``123``.
    """
    stem = os.path.basename(path).split(".")[0]
    return int(stem.split("_")[0])


all_samples = glob("./imagenet_samples/*.JPEG")
# Label id -> path of a training-sample image for that label (if several files
# share an id, the last one returned by glob wins).
qid_to_sample = {_qid_from_sample_path(path): path for path in all_samples}
def get_training_samples(qid):
    """Return one training-sample image path per label of ImageNet-Hard item *qid*.

    Each label id of row ``int(qid)`` is looked up in ``qid_to_sample``; a
    label without a sample raises KeyError.
    """
    label_ids = imagenet_hard[int(qid)]["label"]
    return list(map(qid_to_sample.__getitem__, label_ids))
def load_sample(data, current_index):
    """Return ``(image, english_labels)`` for the queue entry ``data[current_index]``.

    The entry's ``"id"`` is used as a row index into ``imagenet_hard``.
    """
    # Fetch the row once: each dataset row access decodes every column,
    # including the image, so indexing twice decoded the image twice.
    row = imagenet_hard[int(data[current_index]["id"])]
    return row["image"], row["english_label"]
def preprocessing(data, current_index, history, username):
    """Load *username*'s review queue and display its first sample.

    Gradio callback for the "Load Samples" button. The incoming ``data`` is
    discarded and rebuilt with ``generate_dataset(username)``; ``current_index``
    is reset.

    Returns:
        The 7 outputs wired to the button: (query image, label figure,
        current_index, history, data, training-sample images or None,
        number of items the user had already labeled).
    """
    data = generate_dataset(username)
    labeled_images = len(bad_items) - len(data)
    if not data:
        fake_plot = string_to_image("No more images to review")
        empty_image = Image.new("RGB", (224, 224))
        # Reset to the "nothing loaded" sentinel (-1): keeping an index from an
        # earlier queue would point into the now-empty `data`.
        return (
            empty_image,
            fake_plot,
            -1,
            history,
            data,
            None,
            labeled_images,
        )
    current_index = 0
    qimage, labels = load_sample(data, current_index)
    training_paths = get_training_samples(data[current_index]["id"])
    training_samples_image = [Image.open(p).convert("RGB") for p in training_paths]
    # `labels` is a list of label strings; show them as one comma-joined text.
    label_plot = string_to_image(", ".join(labels))
    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        labeled_images,
    )
def _upload_review_decision(image_id, username, decision):
    # Serialize one decision as JSON and upload it to the review-data dataset repo.
    time_stamp = int(time.time())
    decision_dict = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    # One file per decision; update_snapshot() finds them again by the username
    # embedded in the file name.
    filename = f"results_{username}_{time_stamp}.json"
    # Upload the bytes directly instead of writing a temp file in the working
    # directory, so nothing is left behind if the upload raises.
    HfApi().upload_file(
        path_or_fileobj=json.dumps(decision_dict).encode("utf-8"),
        path_in_repo=filename,
        repo_id="taesiri/imagenet_hard_review_data",
        repo_type="dataset",
    )


def _render_review_sample(data, index):
    # Build (query image, label figure, training-sample images) for data[index].
    qimage, labels = load_sample(data, index)
    training_paths = get_training_samples(data[index]["id"])
    training_images = [Image.open(p).convert("RGB") for p in training_paths]
    label_plot = string_to_image(", ".join(labels))
    return qimage, label_plot, training_images


def _review_finished_outputs(current_index, history, data, labeled_images):
    # Closing screen shown once every item in `data` has been decided.
    fake_plot = string_to_image("Thank you for your time!")
    empty_image = Image.new("RGB", (224, 224))
    return (
        empty_image,
        fake_plot,
        current_index,
        history,
        data,
        None,
        labeled_images,
    )


def update_app(decision, data, current_index, history, username):
    """Record *decision* for the current item and advance to the next one.

    Gradio callback for the Accept / Not Sure! / Reject buttons. ``data`` is
    the queue built by preprocessing (list of ``{"id": ...}`` dicts) and
    ``current_index`` the position in it (-1 before any queue is loaded).

    Returns:
        The 7 outputs wired to the buttons: (query image, label figure,
        current_index, history, data, training-sample images or None,
        number of labeled items).
    """
    if current_index == -1:
        fake_plot = string_to_image("Please Enter your username and load samples")
        empty_image = Image.new("RGB", (224, 224))
        return empty_image, fake_plot, current_index, history, data, None, 0

    # `data` only holds items this user has not answered yet, so it can be
    # shorter than bad_items; the end of the queue is len(data) - 1, not the
    # global NUMBER_OF_IMAGES - 1 (which stepped past the end for returning users).
    already_labeled = len(bad_items) - len(data)
    last_index = len(data) - 1

    if current_index > last_index:
        # Already past the final item: show the closing screen again without
        # uploading a duplicate decision.
        return _review_finished_outputs(
            current_index, history, data, already_labeled + len(data)
        )

    _upload_review_decision(data[current_index]["id"], username, decision)
    labeled_images = already_labeled + current_index + 1

    if current_index == last_index:
        # Move past the end so further clicks hit the guard above.
        return _review_finished_outputs(
            current_index + 1, history, data, labeled_images
        )

    current_index += 1
    qimage, label_plot, training_images = _render_review_sample(data, current_index)
    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_images,
        labeled_images,
    )
# Custom CSS for the Blocks layout. `flex-grow: 0 !important;` previously read
# `flex-grow: 0; !important;`, which leaves `!important` as an invalid
# declaration that browsers drop.
newcss = """
#query_image{
}
#nn_gallery {
    height: auto !important;
}
#sample_gallery {
    height: auto !important;
}
/* Set display to flex for the parent element */
.svelte-parentrowclass {
    display: flex;
}
/* Set the flex-grow property for the children elements */
.svelte-parentrowclass > #query_image {
    min-width: min(400px, 40%);
    flex : 1;
    flex-grow: 0 !important;
    border-style: solid;
    height: auto !important;
}
.svelte-parentrowclass > .svelte-rightcolumn {
    flex: 2;
    flex-grow: 0 !important;
    min-width: min(600px, 60%);
}
"""
with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
    # Per-session state: the loaded review queue, the position in it
    # (-1 until samples are loaded) and a history placeholder.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    # Header and reviewer instructions.
    gr.Markdown("# Help Us to Clean `ImageNet-Hard`!")
    gr.Markdown("## Instructions")
    gr.Markdown(
        "Please enter your username and press `Load Samples`. The loading process might take up to a minute. Once the loading is done, you can start reviewing the samples."
    )
    gr.Markdown(
        """For each image, please select one of the following options: `Accept`, `Not Sure!`, `Reject`.
- If you think any of the labels are correct, please select `Accept`.
- If you think none of the labels matching the image, please select `Reject`.
- If you are not sure about the label, please select `Not Sure!`.
You can refer to `Training samples` if you are not sure about the target label.
"""
    )

    # Pre-fill the username with "user-" plus 5 random lowercase letters/digits.
    random_str = "".join(
        [random.choice(string.ascii_lowercase + string.digits) for _ in range(5)]
    )

    # Username / progress row and the load button.
    with gr.Column():
        with gr.Row():
            username = gr.Textbox(label="Username", value=f"user-{random_str}")
            labeled_images = gr.Textbox(label="Labeled Images", value="0")
            total_images = gr.Textbox(label="Total Images", value=len(bad_items))
        prepare_btn = gr.Button(value="Load Samples")

    # Decision buttons, then the query image beside its labels and training samples.
    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row(elem_id="parent_row", elem_classes="svelte-parentrowclass"):
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column(
                elem_id="samples_col",
                elem_classes="svelte-rightcolumn",
            ):
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every callback refreshes the same seven components, in this order.
    review_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
        labeled_images,
    ]
    # The button itself is the first input, so its value ("Accept",
    # "Not Sure!" or "Reject") reaches update_app as `decision`.
    for decision_btn in (accept_btn, myabe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=list(review_outputs),
        )
    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=list(review_outputs),
    )

demo.launch(debug=False)