explore-label-concepts / src /label_interface.py
Xmaster6y's picture
no validation split
bec033e unverified
raw
history blame
6.54 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import random
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def filter_sample(sample, concepts, username, sample_type, *, all_concepts: Optional[list] = None):
    """Return whether ``sample`` passes the concept and labelling-status filter.

    A sample passes when every concept in ``concepts`` is truthy in the
    sample's metadata AND its labelling status for ``username`` matches
    ``sample_type``. A sample counts as labelled by a user when
    ``sample["votes"][username]`` exists and contains every concept in
    ``all_concepts``.

    Args:
        sample: Metadata mapping for one image; read by concept name and,
            when present, under the ``"votes"`` key.
        concepts: Concept names that must all be truthy on the sample. An
            empty list or ``None`` imposes no constraint.
        username: User whose votes decide the labelling status.
        sample_type: ``"labelled"`` or ``"unlabelled"``.
        all_concepts: Concepts a user must have voted on for the sample to
            count as labelled; defaults to ``CONCEPTS``.

    Returns:
        ``True`` if the sample passes the filter, else ``False``.

    Raises:
        ValueError: If ``sample_type`` is neither ``"labelled"`` nor
            ``"unlabelled"``. This is checked before anything else, so a bad
            value is never silently accepted.
    """
    if sample_type not in ("labelled", "unlabelled"):
        raise ValueError(f"Invalid sample type: {sample_type}")
    if all_concepts is None:
        all_concepts = CONCEPTS
    # Generator expressions let all() stop at the first failing concept.
    if not all(sample[c] for c in (concepts or ())):
        return False
    if "votes" in sample and username in sample["votes"]:
        user_votes = sample["votes"][username]
        is_labelled = all(c in user_votes for c in all_concepts)
    else:
        is_labelled = False
    return is_labelled if sample_type == "labelled" else not is_labelled
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile,
):
    """Draw a random sample matching the current filter and build the UI outputs.

    The per-split lists of matching sample indices are cached in
    ``filtered_indices``. They are rebuilt only when ``concepts`` or
    ``sample_type`` differ from the filter the cache was built for
    (``selected_concepts`` / ``selected_sample_type``).

    Args:
        split: Key into ``global_variables.all_metadata`` to draw from.
        concepts: Concepts that must all be truthy on the sample (see
            ``filter_sample``). ``None`` is treated as no constraint.
        sample_type: ``"labelled"`` or ``"unlabelled"``.
        filtered_indices: Cached matching indices per split; updated in place
            when the filter changes.
        selected_concepts: Concepts the cache was last built for.
        selected_sample_type: Sample type the cache was last built for.
        profile: Logged-in user's profile; ``profile.username`` identifies the
            voter.

    Returns:
        A 10-tuple containing, in order:
        - the image path;
        - the concepts the user voted for;
        - the ``"split:index"`` id of the sample;
        - its ``"class"``;
        - a mapping of every concept to the sample's value;
        - the concepts missing from the user's existing votes (empty if the
          user has no votes for this sample);
        - the concepts whose sample value is ``None``;
        - the possibly refreshed ``filtered_indices``, ``selected_concepts``
          and ``selected_sample_type``.
        When no sample matches, a warning is shown and the first seven
        entries are ``None``.
    """
    username = profile.username
    # NOTE(review): an untouched multiselect dropdown may deliver None; treat
    # it as "no concept constraint" rather than crashing inside the filter.
    concepts = concepts or []

    # Rebuild the cache only when the filter inputs change. Votes submitted
    # in between do not trigger a rebuild, so a sample whose labelled status
    # changed since the last rebuild can still be drawn.
    if concepts != selected_concepts or sample_type != selected_sample_type:
        for split_name, samples in global_variables.all_metadata.items():
            filtered_indices[split_name] = [
                idx
                for idx, candidate in enumerate(samples)
                if filter_sample(candidate, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type

    candidates = filtered_indices[split]
    # Explicit emptiness check instead of catching IndexError around the whole
    # body, which would also hide unrelated indexing bugs.
    if not candidates:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type

    sample_idx = random.choice(candidates)
    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No votes recorded for this user on this sample.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]

    tie_concepts = [c for c in CONCEPTS if sample[c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile,
):
    """Store the user's votes for the displayed image, then load the next one.

    Args:
        voted_concepts: Concepts the user checked for the current image.
        current_image: ``"split:index"`` id of the displayed sample, or
            ``None`` when nothing is displayed.
        split, concepts, sample_type, filtered_indices, selected_concepts,
        selected_sample_type: Forwarded unchanged to ``get_next_image``.
        profile: Logged-in user's profile. Its ``username`` is recorded with
            the votes, and it is forwarded to ``get_next_image``.

    Returns:
        The outputs of ``get_next_image``. When ``current_image`` is ``None``,
        a warning is shown and the first seven outputs are ``None``.
    """
    user = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        cleared = (None,) * 7
        return cleared + (filtered_indices, selected_concepts, selected_sample_type)

    global_variables.update_votes(user, current_image, voted_concepts)
    gr.Info("Submit success")
    return get_next_image(
        split=split,
        concepts=concepts,
        sample_type=sample_type,
        filtered_indices=filtered_indices,
        selected_concepts=selected_concepts,
        selected_sample_type=selected_sample_type,
        profile=profile,
    )
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Ask ``global_variables`` to save the logged-in user's work, then notify.

    NOTE(review): what "saving" entails (presumably persisting the user's
    votes) is defined by ``global_variables.save_current_work``; confirm there.

    Args:
        profile: Logged-in user's profile; its ``username`` selects whose
            work is saved.
    """
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")
# Gradio UI layout:
# - left column: filter controls, vote checkboxes, action buttons and sample info;
# - right column: the current image.
# Per-session gr.State values cache the filtered sample indices and the filter
# they were computed for.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                with gr.Row():
                    split = gr.Radio(
                        label="Split",
                        choices=["train", "test"],
                        value="train",
                    )
                    sample_type = gr.Radio(
                        label="Sample Type",
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                    )
                # Explicit empty selection so handlers receive a list, not None.
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                    value=[],
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Local Submit",
                )
            with gr.Row():
                save_button = gr.Button(
                    value="Save",
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # "split:index" id of the displayed sample (set by get_next_image).
    current_image = gr.State(None)
    # Matching sample indices per split; starts as every index of each split.
    filtered_indices = gr.State({
        split_name: list(range(len(samples)))
        for split_name, samples in global_variables.all_metadata.items()
    })
    # Filter the cached indices were built for. None never equals a
    # selectable sample type, so the first request always rebuilds the cache.
    selected_concepts = gr.State([])
    selected_sample_type = gr.State(None)

    # Shared output order for get_next_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]
    next_button.click(
        get_next_image,
        inputs=[split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output,
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output,
    )
    # save_current_work returns None. Binding it to `image` as an output would
    # reset the displayed image on every save, so it gets no outputs.
    save_button.click(
        save_current_work,
    )