explore-label-concepts / src /sample_interface.py
imenelydiaker's picture
integrate-new-concepts (#3)
36c11e4 verified
raw
history blame
5.58 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Load the sample ``step`` positions away from ``index`` in ``split``.

    The parsed index is offset by ``step`` and clamped to
    ``[0, len(samples) - 1]``; a Gradio warning is shown when clamping occurs.

    Args:
        step: Offset applied to the parsed index (e.g. 1 = next, -1 = previous,
            0 = reload the sample at ``index``).
        split: Key into ``global_variables.all_metadata``.
        index: Index text from the UI; falls back to 0 (with a warning) when
            it cannot be parsed as an integer.
        filtered_indices: Returned unchanged as the last output.
        profile: Logged-in user; ``profile.username`` selects whose votes
            are read from ``global_variables.all_votes``.

    Returns:
        A 9-tuple: image path, concepts the user voted truthy,
        ``"<split>:<idx>"`` sample key, the sample's ``"class"``, a
        ``{concept: sample[concept]}`` mapping, the new index as a string,
        concepts missing from the user's stored votes (empty when the user
        has no votes for this sample), concepts whose sample value is
        ``None`` (ties), and ``filtered_indices``.

    Raises:
        KeyError: If ``split`` is not a key of ``all_metadata``.
    """
    username = profile.username
    try:
        int_index = int(index)
    except (TypeError, ValueError):
        # ValueError for non-numeric text, TypeError for a None value. Unlike
        # the previous bare ``except``, this no longer swallows
        # KeyboardInterrupt/SystemExit or unrelated errors.
        gr.Warning("Error parsing index using 0")
        int_index = 0
    sample_idx = int_index + step
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    samples = global_variables.all_metadata[split]
    if sample_idx >= len(samples):
        gr.Warning("No next image.")
        sample_idx = len(samples) - 1
    sample = samples[sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"
    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No votes recorded for this sample, or none from this user.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        # Concepts the user has no stored vote for on this sample.
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]
    tie_concepts = [c for c in CONCEPTS if sample[c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        str(sample_idx),
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    )
def make_get_image(step):
    """Return an event handler that loads the sample ``step`` positions away.

    The returned callable takes ``(split, index, filtered_indices, profile)``
    and forwards everything to ``get_image`` with ``step`` bound.
    """
    def f(
        split: str,
        index: str,
        filtered_indices: dict,
        profile: gr.OAuthProfile
    ):
        # ``profile`` is not among the event inputs wired in the interface;
        # presumably Gradio supplies it from the gr.OAuthProfile annotation.
        return get_image(
            step=step,
            split=split,
            index=index,
            filtered_indices=filtered_indices,
            profile=profile,
        )

    return f


# Navigation handlers: one sample forward, one back, or reload at the index.
get_next_image = make_get_image(1)
get_prev_image = make_get_image(-1)
get_current_image = make_get_image(0)
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    index,
    filtered_indices,
    profile: gr.OAuthProfile
):
    """Record the user's concept votes for the current image, then advance.

    Args:
        voted_concepts: Concepts currently checked by the user.
        current_image: ``"<split>:<idx>"`` key of the displayed sample, or
            ``None`` when no image has been loaded yet.
        split: Current split name, forwarded to ``get_next_image``.
        index: Current index text, forwarded to ``get_next_image``.
        filtered_indices: Forwarded unchanged.
        profile: Logged-in user whose votes are updated.

    Returns:
        The 9-tuple from ``get_next_image``. When no image is selected, it
        returns the same 9 slots instead: everything is ``None`` except
        ``index`` and ``filtered_indices``, which are passed back unchanged.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        # Slot order must match the outputs: image, voted_concepts,
        # current_image, im_class, im_concepts, index, unseen_concepts,
        # tie_concepts, filtered_indices. ``index`` belongs in the sixth
        # slot. Placing it in the tie_concepts slot would clear the index
        # textbox and pass a non-choice value to that CheckboxGroup.
        return None, None, None, None, None, index, None, None, filtered_indices
    global_variables.update_votes(username, current_image, voted_concepts)
    gr.Info("Submit success")
    return get_next_image(
        split,
        index,
        filtered_indices,
        profile
    )
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Delegate to ``global_variables.save_current_work`` for this user.

    Shows a "Save success" notification afterwards and returns ``None``.

    Args:
        profile: Logged-in user; only ``profile.username`` is used.
    """
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")
with gr.Blocks() as interface:
    # NOTE(review): container nesting was reconstructed from a copy whose
    # indentation was lost; confirm it matches the intended layout.
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                prev_button = gr.Button(
                    value="Prev",
                )
                next_button = gr.Button(
                    value="Next",
                )
            gr.LoginButton()
            submit_button = gr.Button(
                value="Local Submit",
            )
            with gr.Row():
                save_button = gr.Button(
                    value="Save",
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # Per-session state: the "<split>:<idx>" key of the displayed sample, and
    # for every split the list of sample indices (initially all of them).
    current_image = gr.State(None)
    filtered_indices = gr.State({
        # ``split_name`` avoids shadowing the ``split`` Radio component.
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })
    # Order must match the 9-tuple returned by get_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    ]
    # ``profile`` is not listed; presumably Gradio injects it from the
    # handlers' gr.OAuthProfile annotation.
    common_input = [split, index, filtered_indices]

    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output
    )
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )
    # save_current_work returns nothing. Wiring it to an output component
    # would overwrite that component with None (e.g. clear the image), so no
    # outputs are declared.
    save_button.click(
        save_current_work,
    )