glitchbench committed on
Commit
248aa67
1 Parent(s): 9ed9f30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ import sys
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import List, Optional
9
+ from uuid import uuid4
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import requests
14
+ from datasets import load_dataset
15
+ from huggingface_hub import (
16
+ CommitScheduler,
17
+ HfApi,
18
+ InferenceClient,
19
+ login,
20
+ snapshot_download,
21
+ )
22
+ from PIL import Image
23
+ from glob import glob
24
+
25
+
26
# Hugging Face access token, read from the "SessionToken" environment variable
# (None when the variable is not set).
session_token = os.environ.get("SessionToken")
login(token=session_token)

# Returned by set_username when no profile is available; start_labeling
# compares against it to detect a user who has not logged in.
DEFAILT_USERNAME_MESSAGE = "You must be logged in befor starting to label images."
REPO_URL = "glitchbench/GlitchBenchReviewData"
DATASET_URL = "glitchbench/GlitchBench"

# Button captions. write_user_description compares the clicked button's value
# against SKIP_MESSAGE to tell a skip from a submission.
SUBMIT_MESSAGE = "Submit the description"
SKIP_MESSAGE = "Skip, I can not spot the bug!"

glitchbench_dataset = load_dataset(DATASET_URL)["validation"]
dataset_size = len(glitchbench_dataset)

# Map each sample id to its index in the split. Only the "id" column is read,
# so the other columns of every row do not have to be loaded to build the map.
id_to_index = {
    sample_id: index for index, sample_id in enumerate(glitchbench_dataset["id"])
}

JSON_DATASET_DIR = Path("local_dataset")
JSON_DATASET_DATA_DIR = JSON_DATASET_DIR / "data"
# Each process start writes to its own uniquely named label file.
JSON_DATASET_PATH = JSON_DATASET_DATA_DIR / f"labels-{uuid4()}.json"

# Creates local_dataset/ and local_dataset/data/ in one call; a no-op when
# they already exist.
JSON_DATASET_DATA_DIR.mkdir(parents=True, exist_ok=True)

print("Downloading the dataset")
print(REPO_URL)

# Fetch the JSON label files already stored in the review repo so that
# start_labeling can see labels recorded by earlier runs.
snapshot_download(
    repo_id=REPO_URL,
    allow_patterns="*.json",
    local_dir=JSON_DATASET_DIR,
    token=session_token,
    repo_type="dataset",
)

# Periodically pushes the contents of JSON_DATASET_DIR back to the review repo.
# NOTE(review): `every` is expressed in minutes per the huggingface_hub docs - confirm.
scheduler = CommitScheduler(
    repo_id=REPO_URL,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="./",
    every=1,
    private=True,
)
72
+
73
+
74
def save_json(image_id: str, provided_description: str, username: str) -> None:
    """Append one labeling record to ``JSON_DATASET_PATH`` as a JSON line.

    The file is opened in append mode and written while ``scheduler.lock`` is
    held; ``scheduler`` is the CommitScheduler that uploads the folder
    containing this file.

    Args:
        image_id: Identifier of the image that was labeled.
        provided_description: Text stored under ``"user_description"``.
        username: Name stored under ``"username"``.
    """
    with scheduler.lock, JSON_DATASET_PATH.open("a") as out_file:
        entry = {
            "username": username,
            "image_id": image_id,
            "user_description": provided_description,
            "datetime": datetime.now().isoformat(),
        }
        json.dump(entry, out_file)
        # One record per line, so the file can be read back line by line.
        out_file.write("\n")
87
+
88
+
89
def set_username(profile: Optional[gr.OAuthProfile]) -> str:
    """Return the text for the username field.

    Args:
        profile: OAuth profile of the current visitor, or ``None`` when no
            profile is available.

    Returns:
        The profile's ``"preferred_username"`` entry, or
        ``DEFAILT_USERNAME_MESSAGE`` when ``profile`` is ``None``.
    """
    if profile is not None:
        return profile["preferred_username"]
    return DEFAILT_USERNAME_MESSAGE
93
+
94
+
95
def start_labeling(username_label):
    """Compute the dataset indices the given user has not labeled yet.

    Every ``*.json`` file in ``JSON_DATASET_DATA_DIR`` is read as one JSON
    object per line. The ``image_id`` of each record whose ``username`` equals
    ``username_label`` is collected, and the matching dataset indices are
    removed from the full index range ``0 .. dataset_size - 1``.

    Args:
        username_label: Current content of the username field.

    Returns:
        A tuple ``(remaining_indices, button)``: the list of dataset indices
        still to be labeled by this user, and a ``gr.Button`` constructed with
        ``interactive=False``.

    Raises:
        gr.Error: If ``username_label`` equals ``DEFAILT_USERNAME_MESSAGE``,
            i.e. the user is not logged in.
    """
    if username_label == DEFAILT_USERNAME_MESSAGE:
        raise gr.Error("Please login first, then click start labeling")

    # Read the JSON files and keep the records related to the current user.
    all_user_records = []
    for json_file in glob(str(JSON_DATASET_DATA_DIR / "*.json")):
        with open(json_file) as f:
            for line in f:
                # A blank line is not a record; json.loads would raise on it.
                if not line.strip():
                    continue
                record = json.loads(line)
                if record["username"] == username_label:
                    all_user_records.append(record["image_id"])

    print(f"Found {len(all_user_records)} records for user {username_label}")

    # Go through all images in the dataset and exclude those already labeled
    # by the user. Ids that are not in the loaded dataset are ignored instead
    # of aborting the session with a KeyError.
    solved_indices = {
        id_to_index[image_id]
        for image_id in all_user_records
        if image_id in id_to_index
    }
    remaining_indicies = set(range(dataset_size)) - solved_indices

    print(f"Found {len(remaining_indicies)} remaining images for user {username_label}")

    return list(remaining_indicies), gr.Button(interactive=False)
120
+
121
+
122
def show_random_sample(username_label, remaining_batch):
    """Draw a random not-yet-shown sample and remove it from the batch.

    Args:
        username_label: Not used by this function; accepted because the event
            wiring passes it.
        remaining_batch: List of dataset indices still to be shown. It is
            modified in place: the drawn index is removed.

    Returns:
        A tuple ``(image, image_id, "", remaining_batch)``: the sample's
        ``"image"`` and ``"id"`` fields, an empty string, and the updated
        batch.

    Raises:
        gr.Error: If ``remaining_batch`` is ``None`` or empty, so there is
            nothing to draw from.
    """
    if not remaining_batch:
        raise gr.Error(
            "No images left to label. If you have not started yet, "
            "click Start Labeling first."
        )

    # Removing by position avoids searching the list for the drawn value.
    rindex = remaining_batch.pop(random.randrange(len(remaining_batch)))

    # Fetch the row once and read both fields from it.
    sample = glitchbench_dataset[rindex]

    return sample["image"], sample["id"], "", remaining_batch
131
+
132
+
133
def write_user_description(username_label, image_id, user_description, skip_or_submit):
    """Store the user's answer for one image via ``save_json``.

    Args:
        username_label: Name recorded as the author of the answer.
        image_id: Identifier of the image being answered.
        user_description: Text the user entered.
        skip_or_submit: Value of the button that triggered the call. When it
            equals ``SKIP_MESSAGE`` the stored description is ``"N/A"``;
            otherwise ``user_description`` is stored as is.
    """
    skipped = skip_or_submit == SKIP_MESSAGE
    save_json(image_id, "N/A" if skipped else user_description, username_label)
140
+
141
+
142
# Build the labeling UI and connect its events.
with gr.Blocks() as demo:
    gr.Markdown("## GlitchBench Dataset Labeling Tool")
    gr.Markdown("Help us to clean and label the GlitchBench dataset.")

    with gr.Row():
        username_label = gr.Text(label="Username", interactive=False)
        gr.LoginButton()
        gr.LogoutButton()

    start_button = gr.Button("Start Labeling")

    # Fill the username field through set_username on page load.
    username_label.attach_load_event(set_username, None)

    glitch_image = gr.Image(label="Image")
    # Hidden field that carries the id of the displayed image to the handlers.
    glitch_image_id = gr.Textbox(label="Image ID", visible=False)

    with gr.Row():
        user_description = gr.Textbox(lines=5, label="Description")
        with gr.Column():
            submit_button = gr.Button(SUBMIT_MESSAGE)
            skip_button = gr.Button(SKIP_MESSAGE)

    # Dataset indices still to be shown to the user; filled by start_labeling.
    remaining_batch = gr.State()

    # show_random_sample returns four values, the last one being the updated
    # batch, so the state is listed as the fourth output to receive it.
    next_sample_inputs = [username_label, remaining_batch]
    next_sample_outputs = [
        glitch_image,
        glitch_image_id,
        user_description,
        remaining_batch,
    ]

    start_button.click(
        start_labeling,
        inputs=[username_label],
        outputs=[remaining_batch, start_button],
    ).then(
        show_random_sample,
        inputs=next_sample_inputs,
        outputs=next_sample_outputs,
    )

    # Submit and skip share one flow: record the answer, then show the next
    # sample. The clicked button is passed as an input so that
    # write_user_description can tell the two apart by its value.
    for answer_button in (submit_button, skip_button):
        answer_button.click(
            write_user_description,
            inputs=[username_label, glitch_image_id, user_description, answer_button],
            outputs=[],
        ).then(
            show_random_sample,
            inputs=next_sample_inputs,
            outputs=next_sample_outputs,
        )

demo.launch()