import os
from io import BytesIO
from multiprocessing import Pool, cpu_count

from datasets import load_dataset
from PIL import Image
import gradio as gr
import pandas as pd

# Kept at module level on purpose: Pool workers call process_image(), which
# reads this object, and under the "spawn" start method (the default on
# macOS/Windows) a worker only sees what importing this module creates.
imagenet_hard_dataset = load_dataset("taesiri/imagenet-hard", split="validation")

THUMBNAIL_PATH = "dataset/thumbnails"
os.makedirs(THUMBNAIL_PATH, exist_ok=True)
max_size = (480, 480)

# Preview grid layout: GRID_SHAPE[0] columns x GRID_SHAPE[1] cells per column.
# Each cell has three output components (image, source textbox, labels
# textbox), see make_grid().
GRID_SHAPE = (4, 4)
PAGE_SIZE = GRID_SHAPE[0] * GRID_SHAPE[1]
GRID_OUTPUTS = PAGE_SIZE * 3

# Populated by main() before the UI is built.
all_origins = set()
all_labels = set()
dataset_df = None


def process_image(i):
    """Save a thumbnail of dataset sample ``i`` and return its index record.

    The sample image is converted to RGB, shrunk in place (aspect ratio kept)
    to fit within ``max_size`` and written as
    ``THUMBNAIL_PATH/<origin>/<i>.jpg``.

    Returns:
        dict with keys ``preview`` (public URL of the thumbnail), ``filepath``
        (local thumbnail path), ``origin`` and ``labels``.
    """
    url_prefix = "https://imagenet-hard.taesiri.ai/"
    # Index the dataset once: every row access decodes the whole row,
    # including the image.
    sample = imagenet_hard_dataset[i]
    image = sample["image"].convert("RGB")
    origin = sample["origin"]
    label = sample["english_label"]

    save_path = os.path.join(THUMBNAIL_PATH, origin)
    # make sure the folder exists
    os.makedirs(save_path, exist_ok=True)
    image_path = os.path.join(save_path, f"{i}.jpg")

    image.thumbnail(max_size, Image.LANCZOS)
    image.save(image_path, "JPEG", quality=100)

    # URLs always use "/", whatever the local path separator is.
    url = url_prefix + image_path.replace(os.sep, "/")
    return {
        "preview": url,
        "filepath": image_path,
        "origin": origin,
        "labels": label,
    }


def load_or_build_index():
    """Return the dataset index DataFrame, building and caching it if needed.

    If ``dataset.pkl`` exists it is loaded. Otherwise every sample is
    thumbnailed in parallel via process_image(), and the result is saved to
    ``dataset.csv`` and ``dataset.pkl``.
    """
    if os.path.exists("dataset.pkl"):
        # NOTE: read_pickle can execute arbitrary code; this file is expected
        # to be the one written below by this script, never an external one.
        return pd.read_pickle("dataset.pkl")

    with Pool(cpu_count()) as pool:
        samples_data = pool.map(process_image, range(len(imagenet_hard_dataset)))
    df = pd.DataFrame(samples_data)
    print(df)

    # save dataframe on disk
    df.to_csv("dataset.csv")
    df.to_pickle("dataset.pkl")
    return df


def _page_count(df):
    """Return the number of PAGE_SIZE pages for ``df`` (at least 1)."""
    # Ceiling division so a trailing partial page is still reachable; never 0
    # so the 1-based slider always has a valid range.
    return max(1, (len(df) + PAGE_SIZE - 1) // PAGE_SIZE)


def _grid_values(df, page):
    """Return GRID_OUTPUTS output values showing 1-based ``page`` of ``df``.

    Each row yields three updates (preview image, origin, labels). The list
    is padded with ``None`` so every grid component receives a value.
    """
    # Before "Show" is clicked, the hidden table has not been filled by
    # get_slice(), so it may not have these columns; show an empty grid.
    if not {"preview", "origin", "labels"}.issubset(df.columns):
        return [None] * GRID_OUTPUTS

    start_index = (max(int(page), 1) - 1) * PAGE_SIZE
    slice_df = df.iloc[start_index:start_index + PAGE_SIZE]

    returned_values = []
    for row in slice_df.itertuples():
        returned_values.append(gr.update(value=row.preview))
        returned_values.append(gr.update(value=row.origin))
        returned_values.append(gr.update(value=row.labels))
    returned_values.extend([None] * (GRID_OUTPUTS - len(returned_values)))
    return returned_values


def get_slice(origin, label):
    """Filter the dataset and show the first page of the result.

    Args:
        origin: Source dataset to keep, or a falsy value for no origin filter.
        label: Label a sample must contain, or a falsy value for no filter.

    Returns:
        The filtered table (for the hidden Dataframe), a slider update
        (range ``1.._page_count``, reset to page 1), then GRID_OUTPUTS values
        for the grid components.
    """
    mask = pd.Series(True, index=dataset_df.index)
    if origin:
        mask &= dataset_df["origin"] == origin
    if label:
        mask &= dataset_df["labels"].apply(lambda labels: label in labels)
    filtered_df = dataset_df[mask]

    return (
        gr.Dataframe(filtered_df, datatype="markdown"),
        gr.update(maximum=_page_count(filtered_df), value=1),
        *_grid_values(filtered_df, 1),
    )


def reset_filters_fn():
    """Clear both filter dropdowns."""
    return gr.update(value=None), gr.update(value=None)


def make_grid(grid_size):
    """Create the preview grid and return its output components.

    ``grid_size[0]`` columns are placed side by side in one row, each holding
    ``grid_size[1]`` cells. Every cell contributes, in order, an image, a
    "Source Dataset" textbox and a "Labels" textbox. The list is built column
    by column, so table rows fill the cells down each column in turn.
    """
    list_of_components = []
    with gr.Row():
        for _ in range(grid_size[0]):
            with gr.Column():
                for _ in range(grid_size[1]):
                    item_image = gr.Image()
                    with gr.Accordion("Click for details", open=False):
                        item_source = gr.Textbox(label="Source Dataset")
                        item_labels = gr.Textbox(label="Labels")
                    list_of_components.extend((item_image, item_source, item_labels))
    return list_of_components


def slider_update(slider, df):
    """Return the grid values for 1-based page ``slider`` of table ``df``."""
    return _grid_values(df, slider)


# Backward-compatible alias for the original (misspelled) handler name.
slider_upadte = slider_update


def build_demo():
    """Build the Gradio UI; main() must have loaded ``dataset_df`` first."""
    with gr.Blocks() as demo:
        gr.Markdown("# ImageNet-Hard Browser")
        # add link to home page and dataset
        gr.HTML("")
        gr.HTML()
        gr.HTML(
            """
Project Home Page  |  Dataset
"""
        )

        with gr.Row():
            # Sorted so the choices appear in a stable order.
            origin_dropdown = gr.Dropdown(sorted(all_origins), label="Origin")
            label_dropdown = gr.Dropdown(sorted(all_labels), label="Label")
        with gr.Row():
            show_btn = gr.Button("Show")
            reset_filters = gr.Button("Reset Filters")

        # Hidden holder for the current filtered table; the slider reads it back.
        preview_dataframe = gr.Dataframe(height=500, visible=False)

        gr.Markdown("## Preview")
        # 1-based page selector: page p shows rows (p-1)*PAGE_SIZE onwards.
        preview_slider = gr.Slider(
            minimum=1, maximum=_page_count(dataset_df), step=1, value=1
        )
        all_components = make_grid(GRID_SHAPE)

        show_btn.click(
            fn=get_slice,
            inputs=[origin_dropdown, label_dropdown],
            outputs=[preview_dataframe, preview_slider, *all_components],
        )
        reset_filters.click(
            fn=reset_filters_fn,
            inputs=[],
            outputs=[origin_dropdown, label_dropdown],
        )
        preview_slider.change(
            fn=slider_update,
            inputs=[preview_slider, preview_dataframe],
            outputs=[*all_components],
        )
    return demo


def main():
    """Load (or build) the dataset index, then build and launch the UI."""
    global dataset_df, all_origins, all_labels
    dataset_df = load_or_build_index()
    all_origins = set(dataset_df["origin"])
    all_labels = set().union(*dataset_df["labels"])
    build_demo().launch()


# Guard required by multiprocessing under "spawn": workers re-import this
# module and must not start their own Pool or launch the server.
if __name__ == "__main__":
    main()