explore-label-concepts / src /sample_interface.py
Xmaster6y's picture
tie and unseen concepts
813250f unverified
raw
history blame
5.97 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Load the sample ``index + step`` of ``split`` and build the UI outputs.

    Args:
        step: Offset added to the parsed index (1 = next, -1 = previous,
            0 = reload current).
        split: Dataset split key into ``global_variables.all_metadata``.
        index: Current index as typed in the Index textbox; unparsable
            values fall back to 0 with a warning.
        filtered_indices: Passed through unchanged to the State output.
        profile: Logged-in user's profile; its ``username`` selects which
            stored vote to display.

    Returns:
        A 9-tuple in the order of the interface's ``common_output``: image
        path, the user's positively voted concepts, ``"split:idx"`` id,
        class, aggregated concepts dict, new index string, concepts the user
        has no stored vote for, tied concepts, ``filtered_indices``.

    Raises:
        gr.Error: if ``split`` contains no samples.
    """
    username = profile.username
    try:
        int_index = int(index)
    except (TypeError, ValueError):
        # Only parsing failures fall back to 0; a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit.
        gr.Warning("Error parsing index using 0")
        int_index = 0

    samples = global_variables.all_metadata[split]
    if not samples:
        # Without this guard the clamp below yields -1 and indexing an empty
        # list raises an opaque IndexError.
        raise gr.Error(f"No images in split '{split}'.")

    sample_idx = int_index + step
    # Clamp into the valid range, warning when navigation hits either end.
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    if sample_idx >= len(samples):
        gr.Warning("No next image.")
        sample_idx = len(samples) - 1

    sample = samples[sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"
    try:
        username_votes = sample["votes"][username]
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        # Concepts absent from the user's stored vote -- presumably concepts
        # added to CONCEPTS after the user last voted on this sample.
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]
    except KeyError:
        # No votes on this sample, or none from this user.
        voted_concepts = []
        unseen_concepts = []
    # Aggregation stores None for a concept whose votes cancel out (a tie).
    tie_concepts = [c for c in sample["concepts"] if sample["concepts"][c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        str(sample_idx),
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    )
def make_get_image(step):
    """Return a Gradio callback that shows the image ``step`` positions away.

    The callback keeps an explicit parameter list, including the
    ``gr.OAuthProfile`` annotation, because Gradio reads that annotation to
    inject the logged-in user's profile. It simply forwards everything to
    ``get_image`` together with the captured ``step``.
    """
    def f(
        split: str,
        index: str,
        filtered_indices: dict,
        profile: gr.OAuthProfile
    ):
        # ``step`` is bound by the enclosing make_get_image call.
        return get_image(
            step=step,
            split=split,
            index=index,
            filtered_indices=filtered_indices,
            profile=profile,
        )
    return f


# Navigation callbacks: advance, go back, or reload the typed index.
get_next_image = make_get_image(1)
get_prev_image = make_get_image(-1)
get_current_image = make_get_image(0)
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    index,
    filtered_indices,
    profile: gr.OAuthProfile
):
    """Record the user's concept vote for the displayed image, then advance.

    The user's vote is stored as ``{concept: bool}`` for every concept in
    CONCEPTS. Each concept's aggregate label is then recomputed over all
    voters: each True vote counts +1 and each False vote -1. A positive sum
    gives True, a negative sum gives False, and zero gives None (a tie).

    Args:
        voted_concepts: Concepts the user ticked.
        current_image: ``"split:idx"`` id of the displayed image, or None if
            no image has been loaded yet.
        split, index, filtered_indices: Current UI values, forwarded to
            ``get_next_image`` after saving.
        profile: Logged-in user's profile; votes are keyed by ``username``.

    Returns:
        The same 9-tuple layout as ``get_image`` (the interface's
        ``common_output``).
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        # Positions must mirror the output order: image, voted, current_image,
        # class, concepts, index, unseen, tie, filtered_indices. ``index`` goes
        # back into the Index textbox, not into tie_concepts.
        return None, None, None, None, None, index, None, None, filtered_indices

    current_split, idx = current_image.split(":")
    idx = int(idx)
    # Presumably refreshes this split's metadata before it is mutated and
    # saved -- confirm against global_variables.
    global_variables.get_metadata(current_split)
    sample = global_variables.all_metadata[current_split][idx]

    votes = sample.setdefault("votes", {})
    votes[username] = {c: c in voted_concepts for c in CONCEPTS}

    new_concepts = {}
    for c in CONCEPTS:
        # True -> +1, False -> -1; voters whose stored vote lacks ``c`` are
        # skipped.
        score = sum(2 * vote[c] - 1 for vote in votes.values() if c in vote)
        new_concepts[c] = None if score == 0 else score > 0
    sample["concepts"] = new_concepts

    global_variables.save_metadata(current_split)
    gr.Info("Submit success")
    return get_next_image(
        split,
        index,
        filtered_indices,
        profile
    )
# Labeling UI. The left column holds split/index selection, the concept
# checkboxes, navigation and submit buttons, and sample info. The right column
# shows the image. Gradio renders components in creation order, so statement
# order here *is* the layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                # Free-text index; parsed with int() by get_image.
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
            with gr.Group():
                # Editable: the concepts the user votes True; sent on Submit.
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                # Filled by get_image with concepts missing from the user's
                # stored vote.
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                # Filled by get_image with concepts whose aggregate is None.
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                prev_button = gr.Button(
                    value="Prev",
                )
                next_button = gr.Button(
                    value="Next",
                )
            # Sign-in is required: every callback reads profile.username.
            gr.LoginButton()
            submit_button = gr.Button(
                value="Submit",
            )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )
    # "split:idx" id of the displayed image; None until one is loaded.
    current_image = gr.State(None)
    # Per-split list of all sample positions. In the comprehension, ``split``
    # is the comprehension's own loop variable; it does not rebind the Radio
    # component named ``split`` above.
    filtered_indices = gr.State({
        split: list(range(len(global_variables.all_metadata[split])))
        for split in global_variables.all_metadata
    })
    # Must match, position for position, the tuples returned by get_image and
    # submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    ]
    # The OAuth profile is not listed here; Gradio supplies it from the
    # callbacks' gr.OAuthProfile annotation.
    common_input = [split, index, filtered_indices]
    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output
    )
    # Pressing Enter in the Index box jumps to that index (step 0).
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )