Spaces:
Sleeping
Sleeping
import json | |
import os | |
import random | |
import re | |
import sys | |
from datetime import datetime | |
from pathlib import Path | |
from typing import List, Optional | |
from uuid import uuid4 | |
import gradio as gr | |
import numpy as np | |
import requests | |
from datasets import load_dataset | |
from huggingface_hub import ( | |
CommitScheduler, | |
HfApi, | |
InferenceClient, | |
login, | |
snapshot_download, | |
) | |
from PIL import Image | |
from glob import glob | |
# Hugging Face access token, read from the "SessionToken" environment
# variable (None when the variable is unset).
session_token = os.environ.get("SessionToken")
login(token=session_token, add_to_git_credential=True)

# Placeholder shown in the username box while nobody is logged in.
# start_labeling compares the box content against this exact string.
DEFAILT_USERNAME_MESSAGE = "You must be logged in befor starting to label images."
# Dataset repo that receives the collected labels.
REPO_URL = "glitchbench/GlitchBenchReviewData"
# Dataset repo holding the images to label.
DATASET_URL = "glitchbench/GlitchBench"
# Button captions; write_user_description tells skip from submit by caption.
SUBMIT_MESSAGE = "Submit the description"
SKIP_MESSAGE = "Skip, I can not spot the bug!"

glitchbench_dataset = load_dataset(DATASET_URL)["validation"]
dataset_size = len(glitchbench_dataset)
# Map each image id to its row index. Only the "id" column is read so that
# building the map does not have to materialize every full row (presumably
# including the image payload).
id_to_index = {
    image_id: index for index, image_id in enumerate(glitchbench_dataset["id"])
}

# Local mirror of the label repo. Each process appends to its own uniquely
# named file under data/, so concurrent instances never write the same file.
JSON_DATASET_DIR = Path("local_dataset")
JSON_DATASET_DATA_DIR = JSON_DATASET_DIR / "data"
JSON_DATASET_PATH = JSON_DATASET_DATA_DIR / f"labels-{uuid4()}.json"

# Creates local_dataset/ and local_dataset/data/ in one step and is a no-op
# when they already exist.
JSON_DATASET_DATA_DIR.mkdir(parents=True, exist_ok=True)

print("Downloading the dataset")
print(REPO_URL)
# Fetch the label files that already exist so that start_labeling can see
# what every user has labeled so far.
snapshot_download(
    repo_id=REPO_URL,
    allow_patterns="*.json",
    local_dir=JSON_DATASET_DIR,
    # `token` is the current name of the deprecated `use_auth_token` argument.
    token=session_token,
    repo_type="dataset",
)

# Background uploader that periodically commits the content of
# JSON_DATASET_DIR to the label repo. `every` is expressed in minutes
# according to the huggingface_hub documentation.
scheduler = CommitScheduler(
    repo_id=REPO_URL,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="./",
    every=1,
    private=True,
)
def save_json(image_id: str, provided_description: str, username: str) -> None:
    """Append one labeling record to this process's JSON-lines file.

    The record holds the username, the image id, the description and the
    current local time in ISO format, and is written as a single JSON
    object followed by a newline.

    Args:
        image_id: Identifier of the image the description refers to.
        provided_description: Text stored under "user_description".
        username: Name of the user who produced the record.
    """
    # The scheduler lock is held for the whole write, presumably so that a
    # background commit never picks up a half-written line.
    with scheduler.lock, JSON_DATASET_PATH.open("a") as out_file:
        entry = {
            "username": username,
            "image_id": image_id,
            "user_description": provided_description,
            "datetime": datetime.now().isoformat(),
        }
        out_file.write(json.dumps(entry))
        out_file.write("\n")
def set_username(profile: Optional[gr.OAuthProfile]) -> str:
    """Return the text to display in the username box.

    Args:
        profile: OAuth profile of the current visitor, or None when nobody
            is logged in.

    Returns:
        The profile's "preferred_username" entry, or the default
        "not logged in" message when there is no profile.
    """
    return (
        DEFAILT_USERNAME_MESSAGE
        if profile is None
        else profile["preferred_username"]
    )
def start_labeling(username_label):
    """Compute the dataset indices the given user still has to label.

    Every ``*.json`` file in the local data directory is read as JSON lines,
    and the image ids recorded for this user are removed from the full range
    of dataset indices.

    Args:
        username_label: Content of the username box.

    Returns:
        A tuple ``(remaining_indices, button)``: the list of dataset indices
        not yet labeled by the user, and a non-interactive ``gr.Button`` used
        to disable the start button.

    Raises:
        gr.Error: If the username box still shows the default
            "not logged in" message.
    """
    if username_label == DEFAILT_USERNAME_MESSAGE:
        raise gr.Error("Please login first, then click start labeling")

    # Read the json files and keep the records related to the current user.
    all_user_records = []
    for json_file in glob(str(JSON_DATASET_DATA_DIR / "*.json")):
        with open(json_file) as f:
            for line in f:
                # Blank lines carry no record and would make json.loads fail.
                if not line.strip():
                    continue
                record = json.loads(line)
                if record["username"] == username_label:
                    all_user_records.append(record["image_id"])
    print(f"Found {len(all_user_records)} records for user {username_label}")

    # Go through all images in the dataset and exclude those that are already
    # labeled by the user. Ids that are not part of the loaded dataset (for
    # example an empty id stored by a submission made before any image was
    # shown) are ignored instead of raising KeyError, which would otherwise
    # prevent this user from ever starting again.
    solved_indices = {
        id_to_index[image_id]
        for image_id in all_user_records
        if image_id in id_to_index
    }
    remaining_indicies = set(range(dataset_size)) - solved_indices
    print(f"Found {len(remaining_indicies)} remaining images for user {username_label}")
    return list(remaining_indicies), gr.Button(interactive=False)
def show_random_sample(username_label, remaining_batch):
    """Pick a random unlabeled image and remove it from the remaining batch.

    Args:
        username_label: Content of the username box (not used here).
        remaining_batch: List of dataset indices still to be labeled; it is
            modified in place. None when labeling has not been started.

    Returns:
        A tuple ``(image, image_id, "", remaining_batch)``: the sampled
        image, its id, an empty string that clears the description box, and
        the updated list of remaining indices.

    Raises:
        gr.Error: If labeling has not been started yet, or if there is no
            image left to label.
    """
    # The state holds None until start_labeling has filled it.
    if remaining_batch is None:
        raise gr.Error("Please click start labeling first")
    # random.choice raises IndexError on an empty list.
    if not remaining_batch:
        raise gr.Error("There are no more images left to label")

    rindex = random.choice(remaining_batch)
    remaining_batch.remove(rindex)
    # Fetch the row once and read both fields from it.
    sample = glitchbench_dataset[rindex]
    return sample["image"], sample["id"], "", remaining_batch
def write_user_description(username_label, image_id, user_description, skip_or_submit):
    """Store the user's answer for the currently displayed image.

    Args:
        username_label: Content of the username box.
        image_id: Id of the displayed image.
        user_description: Text typed into the description box.
        skip_or_submit: Caption of the button that triggered the call. When
            it equals SKIP_MESSAGE the stored description is "N/A",
            otherwise it is ``user_description``.

    Raises:
        gr.Error: If no image id is available, i.e. no image has been shown
            yet.
    """
    # Without an image id there is nothing to attach the description to, and
    # a record with an empty id would not match any dataset entry.
    if not image_id:
        raise gr.Error("Please click start labeling first")

    provided_description = (
        "N/A" if skip_or_submit == SKIP_MESSAGE else user_description
    )
    save_json(image_id, provided_description, username_label)
# Build the labeling UI and wire its events.
# NOTE(review): the nesting of the layout containers below is reconstructed
# from the order of the calls; confirm it against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## GlitchBench Dataset Labeling Tool")
    gr.Markdown("Help us to clean and label the GlitchBench dataset.")

    with gr.Row():
        username_label = gr.Text(label="Username", interactive=False)
        gr.LoginButton()
        gr.LogoutButton()
        start_button = gr.Button("Start Labeling")
    # Fill the username box through set_username when the page loads.
    username_label.attach_load_event(set_username, None)

    glitch_image = gr.Image(label="Image")
    # Hidden field carrying the id of the displayed image between events.
    glitch_image_id = gr.Textbox(label="Image ID", visible=False)

    with gr.Row():
        user_description = gr.Textbox(lines=5, label="Description")
        with gr.Column():
            submit_button = gr.Button(SUBMIT_MESSAGE)
            skip_button = gr.Button(SKIP_MESSAGE)

    # Per-session list of dataset indices the user still has to label.
    remaining_batch = gr.State()

    # show_random_sample returns four values (image, id, cleared description,
    # updated batch), so the state must be listed as the fourth output for
    # the updated batch to be stored.
    sample_inputs = [username_label, remaining_batch]
    sample_outputs = [glitch_image, glitch_image_id, user_description, remaining_batch]

    # Start: compute the remaining batch, disable the button, show an image.
    start_button.click(
        start_labeling,
        inputs=[username_label],
        outputs=[remaining_batch, start_button],
    ).then(
        show_random_sample,
        inputs=sample_inputs,
        outputs=sample_outputs,
    )

    # Submit and skip share the same flow: store the answer, then show the
    # next image. The button is passed as the last input so the handler
    # receives its caption and can tell the two actions apart.
    for action_button in (submit_button, skip_button):
        action_button.click(
            write_user_description,
            inputs=[username_label, glitch_image_id, user_description, action_button],
            outputs=[],
        ).then(
            show_random_sample,
            inputs=sample_inputs,
            outputs=sample_outputs,
        )

demo.launch()