"""Interface for labeling concepts in images. """ from typing import Optional import random import gradio as gr from src import global_variables from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME def filter_sample(sample, concepts, username, sample_type): has_concepts = all([sample["concepts"].get(c, False) for c in concepts]) if not has_concepts: return False if "votes" in sample and username in sample["votes"]: is_labelled = all([c in sample["votes"][username] for c in CONCEPTS]) else: is_labelled = False if sample_type == "labelled": return is_labelled elif sample_type == "unlabelled": return not is_labelled else: raise ValueError(f"Invalid sample type: {sample_type}") def get_next_image( split: str, concepts: list, sample_type: str, filtered_indices: dict, selected_concepts: list, selected_sample_type: str, profile: gr.OAuthProfile ): username = profile.username if concepts != selected_concepts or sample_type != selected_sample_type: for key, values in global_variables.all_metadata.items(): filtered_indices[key] = [i for i in range(len(values)) if filter_sample(values[i], concepts, username, sample_type)] selected_concepts = concepts selected_sample_type = sample_type try: sample_idx = random.choice(filtered_indices[split]) sample = global_variables.all_metadata[split][sample_idx] image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}" try: username_votes = global_variables.all_votes[sample["id"]][username] voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)] unseen_concepts = [c for c in CONCEPTS if c not in username_votes] except KeyError: voted_concepts = [] unseen_concepts = [] tie_concepts = [c for c in sample["concepts"] if sample["concepts"][c] is None] return ( image_path, voted_concepts, f"{split}:{sample_idx}", sample["class"], sample["concepts"], unseen_concepts, tie_concepts, filtered_indices, selected_concepts, selected_sample_type, ) except IndexError: gr.Warning("No image found for the selected filter.") return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type def submit_label( voted_concepts: list, current_image: Optional[str], split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type, profile: gr.OAuthProfile ): username = profile.username if current_image is None: gr.Warning("No image selected.") return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type global_variables.update_votes(username, current_image, voted_concepts) gr.Info("Submit success") return get_next_image( split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type, profile ) def save_current_work( profile: gr.OAuthProfile, ): username = profile.username global_variables.save_current_work(username) gr.Info("Save success") with gr.Blocks() as interface: with gr.Row(): with gr.Column(): with gr.Group(): gr.Markdown( "## # Image Selection", ) with gr.Row(): split = gr.Radio( label="Split", choices=["train", "validation", "test"], value="train", ) sample_type = gr.Radio( label="Sample Type", choices=["labelled", "unlabelled"], value="unlabelled", ) concepts = gr.Dropdown( label="Concepts", multiselect=True, choices=CONCEPTS, ) with gr.Group(): voted_concepts = gr.CheckboxGroup( label="Voted Concepts", choices=CONCEPTS, ) unseen_concepts = gr.CheckboxGroup( label="Previously Unseen Concepts", choices=CONCEPTS, ) tie_concepts = gr.CheckboxGroup( label="Tie Concepts", choices=CONCEPTS, ) with gr.Row(): next_button = gr.Button( value="Next", ) gr.LoginButton() submit_button = gr.Button( value="Local Submit", ) with gr.Row(): save_button = gr.Button( value="Save", ) with gr.Group(): gr.Markdown( "## # Image Info", ) im_class = gr.Textbox( label="Class", ) im_concepts = gr.JSON( label="Concepts", ) with gr.Column(): image = gr.Image( label="Image", ) current_image = gr.State(None) filtered_indices = gr.State({ split: list(range(len(global_variables.all_metadata[split]))) for split in global_variables.all_metadata }) selected_concepts = gr.State([]) selected_sample_type = gr.State(None) common_output = [ image, voted_concepts, current_image, im_class, im_concepts, unseen_concepts, tie_concepts, filtered_indices, selected_concepts, selected_sample_type, ] next_button.click( get_next_image, inputs=[split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type], outputs=common_output ) submit_button.click( submit_label, inputs=[voted_concepts, current_image, split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type], outputs=common_output ) save_button.click( save_current_work, outputs=[image] )