import os
os.environ['GRADIO_TEMP_DIR'] = "tmp/"
import gradio as gr
import json
import random
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
import torch
import shutil
import argparse
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from coda import CODA
from coda.datasets import Dataset
from coda.options import LOSS_FNS
from coda.oracle import Oracle

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='Enable debug mode with delete button')
args_cli = parser.parse_args()
DEBUG_MODE = args_cli.debug

if DEBUG_MODE:
    print("Debug mode enabled - delete button will be available")

# Create deleted_in_app directory if it doesn't exist
os.makedirs('deleted_in_app', exist_ok=True)

with open('iwildcam_demo_annotations.json', 'r') as f:
    data = json.load(f)

SPECIES_MAP = OrderedDict([
    (24, "Jaguar"),          # panthera onca
    (10, "Ocelot"),          # leopardus pardalis
    (6, "Mountain Lion"),    # puma concolor
    (101, "Common Eland"),   # tragelaphus oryx
    (102, "Waterbuck"),      # kobus ellipsiprymnus
])
NAME_TO_ID = {name: id for id, name in SPECIES_MAP.items()}

# Class names in order (0-4) from classes.txt
CLASS_NAMES = ["Jaguar", "Ocelot", "Mountain Lion", "Common Eland", "Waterbuck"]
NAME_TO_CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}

# Model information from models.txt
MODEL_INFO = [
    {"org": "Facebook", "name": "PE-Core", "logo": "logos/meta.png"},
    {"org": "Google", "name": "SigLIP2", "logo": "logos/google.png"},
    {"org": "OpenAI", "name": "CLIPViT-L", "logo": "logos/openai.png"},
    {"org": "Imageomics", "name": "BioCLIP2", "logo": "logos/imageomics.png"},
    {"org": "LAION", "name": "LAION CLIP", "logo": "logos/laion.png"}
]

DEMO_LEARNING_RATE = 0.05  # don't use default; use something more fun
DEMO_ALPHA = 0.9  # 0.25 is more fun if showing the confusion matrices

# Toggle between confusion matrix and accuracy chart
USE_CONFUSION_MATRIX = False  # Set to True for confusion matrices, False for accuracy bars


def create_species_guide_content():
    """Create the species identification guide content"""
    with gr.Column():
        gr.Markdown("""
# Species Classification Guide
### Learn to identify the five wildlife species in this demo.

## Jaguar
""")
        gr.Image("species_id/jaguar.jpg", label="Jaguar example image", show_label=False)
        gr.Markdown("""
#### The largest cat in the Americas, with a stocky, muscular build and a broad head. Coat is patterned with rosettes that often have central spots inside.
----
## Ocelot
""")
        gr.Image("species_id/ocelot.jpg", label="Ocelot example image", show_label=False)
        gr.Markdown("""
#### Smaller and leaner than a jaguar, with more elongated markings and rounder ears.
----
## Mountain Lion
""")
        gr.Image("species_id/mountainlion.jpg", label="Mountain lion example image", show_label=False)
        gr.Markdown("""
#### Also called cougar or puma, this cat has a plain tawny or grayish coat without spots. Its long tail and uniformly colored fur distinguish it from jaguars and ocelots.
----
## Common Eland
""")
        gr.Image("species_id/commoneland.jpg", label="Eland example image", show_label=False)
        gr.Markdown("""
#### The largest antelope species. Identifiable by the spiraled horns present on both sexes. Lighter tan coat than a waterbuck.
----
## Waterbuck
""")
        gr.Image("species_id/waterbuck.jpg", label="Waterbuck example image", show_label=False)
        gr.Markdown("""
#### A shaggy, dark brown antelope. Identifiable by backward-curving horns in males, no horns on females. Larger, rounder ears and darker coat than the common eland.
----
""")


# load image metadata
images_data = []
for annotation in tqdm(data['annotations'], desc='Loading annotations'):
    image_id = annotation['image_id']
    category_id = annotation['category_id']
    image_info = next((img for img in data['images'] if img['id'] == image_id), None)
    if image_info:
        images_data.append({
            'filename': image_info['file_name'],
            'species_id': category_id,
            'species_name': SPECIES_MAP[category_id]
        })

print(f"Loaded {len(images_data)} images for the quiz")

# Load image filenames list
with open('images.txt', 'r') as f:
    full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]

# Initialize full dataset (will be subsampled per-user)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load full dataset
full_preds = torch.load("iwildcam_demo.pt").to(device)
full_labels = torch.load("iwildcam_demo_labels.pt").to(device)

# Pre-compute class indices for subsampling
from collections import defaultdict
full_class_to_indices = defaultdict(list)
for idx, label in enumerate(full_labels):
    class_idx = label.item()
    full_class_to_indices[class_idx].append(idx)

# Find minimum class size
min_class_size = min(len(indices) for indices in full_class_to_indices.values())
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")

# Loss function for oracle
loss_fn = LOSS_FNS['acc']

# Global state (will be set per-user in start_demo)
current_image_info = None
coda_selector = None
oracle = None
dataset = None
image_filenames = None
iteration_count = 0


def get_model_predictions(chosen_idx):
    """Get model predictions and scores for a specific image"""
    global dataset
    if dataset is None or chosen_idx >= dataset.preds.shape[1]:
        return "No predictions available"

    # Get predictions for this image (shape: [num_models, num_classes])
    image_preds = dataset.preds[:, chosen_idx, :].detach().cpu().numpy()

    predictions_list = []
    for model_idx in range(image_preds.shape[0]):
        model_scores = image_preds[model_idx]
        predicted_class_idx = model_scores.argmax()
        predicted_class_name = CLASS_NAMES[predicted_class_idx]
        confidence = model_scores[predicted_class_idx]
        model_info = MODEL_INFO[model_idx]
        predictions_list.append(f"**{model_info['name']}:** {predicted_class_name} *({confidence:.3f})*")

    predictions_text = "### Model Predictions\n\n" + " | ".join(predictions_list)
    return predictions_text


def add_logo_to_x_axis(ax, x_pos, logo_path, model_name, height_px=35):
    """Add a logo image to x-axis next to model name"""
    try:
        img = mpimg.imread(logo_path)
        # Calculate zoom to achieve desired height in pixels
        # Rough conversion: height_px / (shorter image side) / dpi * 72
        zoom = height_px / min(img.shape[0], img.shape[1]) / ax.figure.dpi * 72
        imagebox = OffsetImage(img, zoom=zoom)
        # Position logo to the left of the x-tick
        logo_offset = -0.28  # Adjust this to move logo left/right relative to tick
        y_offset = -0.08
        ab = AnnotationBbox(imagebox, (x_pos + logo_offset, y_offset),
                            xycoords=('data', 'axes fraction'), frameon=False)
        ax.add_artist(ab)
    except Exception as e:
        print(f"Could not load logo {logo_path}: {e}")
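
# NOTE (added comment): downstream code indexes predictions as preds[:, item_idx, :] and
# preds[:, subsampled_indices, :], so full_preds is assumed to be laid out as
# [num_models, num_images, num_classes], with full_labels holding one integer class index
# (0-4) per image. Printing the shapes makes the assumed layout easy to verify at startup:
print(f"full_preds shape (models x images x classes): {tuple(full_preds.shape)}")
print(f"full_labels shape (images,): {tuple(full_labels.shape)}")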
def get_next_coda_image():
    """Get the next image that CODA wants labeled"""
    global current_image_info, coda_selector, iteration_count

    # Get next item from CODA
    chosen_idx, selection_prob = coda_selector.get_next_item_to_label()
    print("CODA chosen_idx, selection prob:", chosen_idx, selection_prob)

    # Get the corresponding image filename
    if chosen_idx < len(image_filenames):
        filename = image_filenames[chosen_idx]
        image_path = os.path.join('iwildcam_demo_images', filename)
        print("Next image is", filename)

        # Find the corresponding annotation for this image
        current_image_info = None
        for annotation in data['annotations']:
            image_id = annotation['image_id']
            image_info = next((img for img in data['images'] if img['id'] == image_id), None)
            if image_info and image_info['file_name'] == filename:
                current_image_info = {
                    'filename': filename,
                    'species_id': annotation['category_id'],
                    'species_name': SPECIES_MAP[annotation['category_id']],
                    'chosen_idx': chosen_idx,
                    'selection_prob': selection_prob
                }
                break

        try:
            image = Image.open(image_path)
            predictions = get_model_predictions(chosen_idx)
            return image, f"Iteration {iteration_count}: CODA selected this image for labeling", predictions
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None, f"Error loading image: {e}", "No predictions available"
    else:
        return None, "Image index out of range", "No predictions available"


def delete_current_image():
    """Delete the current image by moving it to deleted_in_app directory"""
    global current_image_info, coda_selector

    if current_image_info is None:
        return "No image to delete!", None, "No predictions", None, None, ""

    filename = current_image_info['filename']
    chosen_idx = current_image_info['chosen_idx']
    source_path = os.path.join('iwildcam_demo_images', filename)
    dest_path = os.path.join('deleted_in_app', filename)

    try:
        shutil.move(source_path, dest_path)
        result = f"✓ Moved {filename} to deleted_in_app/"
        print(f"Deleted image: {filename}")
        # Remove from CODA's unlabeled indices without adding a label
        if chosen_idx in coda_selector.unlabeled_idxs:
            coda_selector.unlabeled_idxs.remove(chosen_idx)
    except Exception as e:
        result = f"Error deleting image: {e}"
        print(f"Error deleting {filename}: {e}")

    # Load next image
    next_image, status, predictions = get_next_coda_image()
    status_html = f'{status} ?'

    # Get updated plots
    prob_plot = create_probability_chart()
    accuracy_plot = create_accuracy_chart()

    return result, next_image, predictions, prob_plot, accuracy_plot, status_html


def check_answer(user_choice):
    """Process user's label and update CODA"""
    global current_image_info, coda_selector, iteration_count

    if current_image_info is None:
        return "Please load an image first!", "", None, "No predictions", None, None

    correct_species = current_image_info['species_name']
    chosen_idx = current_image_info['chosen_idx']
    selection_prob = current_image_info['selection_prob']

    # Convert user choice to class index (0-4)
    if user_choice == "I don't know":
        # For "I don't know", just remove from sampling without providing label
        coda_selector.unlabeled_idxs.remove(chosen_idx)
        result = f"The last image was skipped and will not be used for model selection. The correct species was {correct_species}. "
    else:
        user_class_idx = NAME_TO_CLASS_IDX.get(user_choice, NAME_TO_CLASS_IDX[correct_species])
        if user_choice == correct_species:
            result = f"🎉 Your last classification was correct! It was indeed a {correct_species}."
        else:
            result = f"❌ Your last classification was incorrect. It was a {correct_species}, not a {user_choice}. This may mislead the model selection process!"
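
        # Note on the update below (added comment): add_label() receives not just the item
        # index and the (possibly incorrect) user-provided class, but also the probability
        # with which CODA chose to query this image. Passing the selection probability back
        # lets the selector account for how active sampling skews which items receive labels;
        # see the paper for the exact weighting scheme.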
        # Update CODA with the label
        coda_selector.add_label(chosen_idx, user_class_idx, selection_prob)
        iteration_count += 1

    # Get updated plots
    prob_plot = create_probability_chart()
    accuracy_plot = create_accuracy_chart()

    # Load next image
    next_image, status, predictions = get_next_coda_image()

    # Create HTML with inline help button for status
    status_html = f'{status} ?'

    return result, status_html, next_image, predictions, prob_plot, accuracy_plot


def create_probability_chart():
    """Create a bar chart showing probability each model is best"""
    global coda_selector

    if coda_selector is None:
        # Fallback for initial state
        model_labels = [info['name'] for info in MODEL_INFO]
        probabilities = np.ones(len(MODEL_INFO)) / len(MODEL_INFO)  # Uniform prior
    else:
        probs_tensor = coda_selector.get_pbest()
        probabilities = probs_tensor.detach().cpu().numpy().flatten()
        model_labels = [" " * (9 if info['name'] == 'LAION CLIP' else 4 if info['name'] == 'SigLIP2' else 6) + info['name']
                        for info in MODEL_INFO[:len(probabilities)]]

    # Find the index of the highest probability
    best_idx = np.argmax(probabilities)

    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)

    # Create colors array - highlight the best model
    colors = ['orange' if i == best_idx else 'steelblue' for i in range(len(model_labels))]
    bars = ax.bar(range(len(model_labels)), probabilities, color=colors, alpha=0.7)

    # Add text above the highest bar
    ax.text(best_idx, probabilities[best_idx] + 0.0025, 'Current best guess',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_ylabel('Probability model is best', fontsize=12)
    ax.set_title(f'CODA Model Selection Probabilities (Iteration {iteration_count})', fontsize=12)
    ax.set_ylim(np.min(probabilities) - 0.01, np.max(probabilities) + 0.02)

    # Set x-axis labels and ticks
    ax.set_xticks(range(len(model_labels)))
    ax.set_xticklabels(model_labels, fontsize=12, ha='center')

    # Add logos to x-axis
    for i, model_info in enumerate(MODEL_INFO[:len(probabilities)]):
        add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])

    plt.yticks(fontsize=12)
    plt.tight_layout()

    # Save the figure and close it to prevent memory leaks
    temp_fig = fig
    plt.close(fig)
    return temp_fig


def create_accuracy_chart():
    """Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle"""
    global coda_selector, oracle, dataset, iteration_count
    if USE_CONFUSION_MATRIX:
        return create_confusion_matrix_chart()
    else:
        return create_accuracy_bar_chart()


def create_confusion_matrix_chart():
    """Create confusion matrix estimates for each model side by side"""
    global coda_selector, iteration_count

    if coda_selector is None:
        # Fallback for initial state - return empty figure
        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
        ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        plt.tight_layout()
        temp_fig = fig
        plt.close(fig)
        return temp_fig

    # Get confusion matrix estimates from CODA's Dirichlet distributions
    dirichlets = coda_selector.dirichlets  # Shape: [num_models, num_classes, num_classes]
    num_models = dirichlets.shape[0]
    num_classes = dirichlets.shape[1]

    # Convert Dirichlet parameters to expected confusion matrices
    # The expected value of a Dirichlet is alpha / sum(alpha)
    confusion_matrices = []
    for model_idx in range(num_models):
        alpha = dirichlets[model_idx].detach().cpu().numpy()
        # Normalize each row to get probabilities
        conf_matrix = alpha / alpha.sum(axis=1, keepdims=True)
        confusion_matrices.append(conf_matrix)
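
    # Worked example (added for clarity, not used by the code): if one row of a model's
    # Dirichlet has parameters alpha = [8, 1, 1, 1, 1], then sum(alpha) = 12 and the expected
    # confusion-matrix row is alpha / 12 ≈ [0.67, 0.08, 0.08, 0.08, 0.08], i.e. CODA currently
    # expects this model to predict the true class about 67% of the time for that species.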
    # Create subplots for each model
    # Adjust width based on number of models (2.4 inches per model works well)
    fig_width = num_models * 2.4
    fig, axes = plt.subplots(1, num_models, figsize=(fig_width, 2.8), dpi=150)
    if num_models == 1:
        axes = [axes]

    # Species abbreviations for axis labels
    species_labels = ['Jag', 'Oce', 'M.L.', 'C.E.', 'Wat']

    for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
        # Apply repeated square-root (16th-root) scaling to make small values more visible
        # This expands small values while still showing large values
        sqrt_conf_matrix = np.sqrt(np.sqrt(np.sqrt(np.sqrt(conf_matrix))))

        # Plot confusion matrix as heatmap with sqrt-scaled values
        im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto')  # , vmin=0, vmax=1

        # Add model name as title
        model_info = MODEL_INFO[model_idx]
        ax.set_title(f"{model_info['name']}", fontsize=10, pad=5)

        # Set axis labels
        if model_idx == 0:
            ax.set_ylabel('True class', fontsize=9)
        ax.set_xlabel('Predicted', fontsize=9)

        # Set ticks with species abbreviations
        ax.set_xticks(range(num_classes))
        ax.set_yticks(range(num_classes))
        ax.set_xticklabels(species_labels[:num_classes], fontsize=8)
        ax.set_yticklabels(species_labels[:num_classes], fontsize=8)

    plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98)
    plt.tight_layout()
    temp_fig = fig
    plt.close(fig)
    return temp_fig


def create_accuracy_bar_chart():
    """Create a bar chart showing true accuracy of each model (with muted colors)"""
    global oracle, dataset

    if oracle is None or dataset is None:
        # Fallback for initial state
        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
        ax.text(0.5, 0.5, 'Start demo to see model accuracies',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        plt.tight_layout()
        temp_fig = fig
        plt.close(fig)
        return temp_fig

    true_losses = oracle.true_losses(dataset.preds)
    # Convert losses to accuracies (assuming loss is 1 - accuracy)
    accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
    model_labels = [" " * (9 if info['name'] == 'LAION CLIP' else 4 if info['name'] == 'SigLIP2' else 6) + info['name']
                    for info in MODEL_INFO[:len(accuracies)]]

    # Find the index of the highest accuracy
    best_idx = np.argmax(accuracies)

    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)

    # Create colors array - highlight the best model with dark reddish orange, others soft pink
    colors = ['#F8481C' if i == best_idx else '#F8BBD0' for i in range(len(model_labels))]
    bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.85)

    # Add text above the highest bar
    ax.text(best_idx, accuracies[best_idx] + 0.0025, 'True best model',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
    ax.set_title('True Model Accuracies', fontsize=12)

    y_min = np.min(accuracies) - 0.025
    y_max = np.max(accuracies) + 0.05
    ax.set_ylim(y_min, y_max)

    # Add accuracy values in the middle of the visible portion of each bar
    for i, (bar, acc) in enumerate(zip(bars, accuracies)):
        # Position text in the middle of the visible part of the bar
        text_y = (y_min + acc) / 2
        # Use black text for all bars
        text_color = '#000000'
        ax.text(i, text_y, f'{acc:.3f}', ha='center', va='center',
                fontsize=10, fontweight='bold', color=text_color)

    # Set x-axis labels and ticks
    ax.set_xticks(range(len(model_labels)))
    ax.set_xticklabels(model_labels, fontsize=12, ha='center')

    # Add logos to x-axis
    for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
        add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])

    plt.yticks(fontsize=12)
    plt.tight_layout()

    # Save the figure and close it to prevent memory leaks
    temp_fig = fig
    plt.close(fig)
    return temp_fig
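
# ---------------------------------------------------------------------------
# Illustrative sketch (added; not used by the app): the Gradio callbacks above drive CODA
# interactively, but the same loop can be run headlessly by feeding CODA the ground-truth
# labels instead of user clicks. This sketch only reuses calls that already appear in this
# file (CODA(...), get_next_item_to_label(), add_label(), get_pbest(), Oracle.true_losses());
# the function name `run_headless_example` and the default of 10 labels are arbitrary choices.
def run_headless_example(num_labels=10):
    """Run CODA on the full dataset with oracle labels and report its best-model guess."""
    # Build a Dataset the same way the app does for per-user subsets
    sim_dataset = Dataset.__new__(Dataset)
    sim_dataset.preds = full_preds
    sim_dataset.labels = full_labels
    sim_dataset.device = device

    sim_oracle = Oracle(sim_dataset, loss_fn=loss_fn)
    sim_selector = CODA(sim_dataset, learning_rate=DEMO_LEARNING_RATE, alpha=DEMO_ALPHA)

    for _ in range(num_labels):
        idx, prob = sim_selector.get_next_item_to_label()
        # Use the ground-truth label in place of a human annotator
        sim_selector.add_label(idx, full_labels[idx].item(), prob)

    pbest = sim_selector.get_pbest().detach().cpu().numpy().flatten()
    true_accs = (1 - sim_oracle.true_losses(sim_dataset.preds)).detach().cpu().numpy().flatten()
    print("CODA's best guess:", MODEL_INFO[int(np.argmax(pbest))]['name'])
    print("True best model:  ", MODEL_INFO[int(np.argmax(true_accs))]['name'])
# ---------------------------------------------------------------------------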

# Create the Gradio interface
with gr.Blocks(
    title="CODA: Wildlife Photo Classification Challenge",
    theme=gr.themes.Base(),
    css="""
.subtle-outline { border: 1px solid var(--border-color-primary) !important; background: var(--background-fill-secondary) !important; border-radius: var(--radius-lg); padding: 1rem; }
.subtle-outline .flex { background-color: var(--background-fill-secondary) !important; }
/* Light blue background for model predictions panel */
.model-predictions-panel { border: 1px solid #6B8CBF !important; background: #D6E4F5 !important; border-radius: var(--radius-lg); padding: 0.3rem !important; margin: 0.2rem 0 !important; }
.model-predictions-panel .flex { background-color: #D6E4F5 !important; padding: 0 !important; margin: 0 !important; }
.model-predictions-panel * { color: #1a1a1a !important; }
/* Popup overlay styles */
.popup-overlay { position: fixed; top: 0; left: 0; width: 100%; height: 100%; background-color: rgba(0, 0, 0, 0.5); z-index: 1000; display: flex; justify-content: center; align-items: center; }
.popup-overlay > div { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; }
.popup-content { background: var(--background-fill-primary) !important; padding: 2rem !important; border-radius: 1rem !important; max-width: 850px; width: 90%; max-height: 80vh; overflow-y: auto; box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3); border: none !important; margin: 0 !important; color: var(--body-text-color) !important; }
.popup-content > div { background: var(--background-fill-primary) !important; border: none !important; padding: 0 !important; margin: 0 !important; overflow-y: visible !important; max-height: none !important; }
.popup-content h1, .popup-content h2, .popup-content h3, .popup-content p, .popup-content li { color: var(--body-text-color) !important; }
/* Ensure gradio column components don't interfere with scrolling */
.popup-content .gradio-column { overflow-y: visible !important; max-height: none !important; }
/* Ensure images in popup are responsive */
.popup-content img { max-width: 100% !important; height: auto !important; }
/* Center title */
.text-center { text-align: center !important; }
/* Right align text */
.text-right { text-align: right !important; }
/* Subtitle styling */
.subtitle { text-align: center !important; font-weight: 300 !important; color: #666 !important; margin-top: -0.5rem !important; }
/* Question mark icon styling */
.panel-container { position: relative; }
.help-icon { position: absolute; top: 5px; right: 5px; width: 25px; height: 25px; background-color: #f8f9fa; color: #6c757d; border: 1px solid #dee2e6; border-radius: 50%; display: flex; align-items: center; justify-content: center; cursor: pointer; font-size: 13px; font-weight: 600; z-index: 10; transition: all 0.2s ease; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
.help-icon:hover { background-color: #e9ecef; color: #495057; border-color: #adb5bd; box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15); }
/* Help popup styles */
.help-popup-overlay { position: fixed; top: 0; left: 0; width: 100%; height: 100%; background-color: rgba(0, 0, 0, 0.5); z-index: 1001; display: flex; justify-content: center; align-items: center; }
.help-popup-overlay > div { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; }
.help-popup-content { background: var(--background-fill-primary) !important; padding: 1.5rem !important; border-radius: 0.5rem !important; max-width: 600px; width: 90%; box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3); border: none !important; margin: 0 !important; color: var(--body-text-color) !important; }
.help-popup-content > div { background: var(--background-fill-primary) !important; border: none !important; padding: 0 !important; margin: 0 !important; }
.help-popup-content h1, .help-popup-content h2, .help-popup-content h3, .help-popup-content p, .help-popup-content li { color: var(--body-text-color) !important; }
/* Inline help button */
.inline-help-btn { display: inline-block; width: 20px; height: 20px; background-color: #f8f9fa; color: #6c757d; border: 1px solid #dee2e6; border-radius: 50%; text-align: center; line-height: 18px; cursor: pointer; font-size: 11px; font-weight: 600; margin-left: 8px; vertical-align: middle; transition: all 0.2s ease; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
.inline-help-btn:hover { background-color: #e9ecef; color: #495057; border-color: #adb5bd; box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15); }
#hidden-selection-help-btn { display: none; }
/* Reduce spacing around status text */
.status-text { margin: 0 !important; padding: 0 !important; }
.status-text > div { margin: 0 !important; padding: 0 !important; }
/* Compact model predictions panel */
.compact-predictions { line-height: 1.1 !important; margin: 0 !important; padding: 0.1rem !important; }
.compact-predictions p { margin: 0.05rem 0 !important; }
.compact-predictions h3 { margin: 0 0 0.1rem 0 !important; }
/* Target the subtle-outline group that contains predictions */
.subtle-outline { padding: 0.3rem !important; margin: 0.2rem 0 !important; }
/* Target the column inside the outline */
.subtle-outline .flex { padding: 0 !important; margin: 0 !important; }
/* Ensure text in predictions panel is visible in dark mode */
.subtle-outline * { color: var(--body-text-color) !important; }
""",
    delete_cache=(3600, 3600)  # once per hour - clear old javascript
) as demo:

    # Main page title
    gr.Markdown("# CODA: Consensus-Driven Active Model Selection", elem_classes="text-center")
    gr.Markdown("*Figure out which model is best by actively annotating data. See the paper for more details!*", elem_classes="text-center")

    # Add buttons row
    with gr.Row():
        view_guide_button = gr.Button("📖 View Species Guide", variant="secondary", size="lg")
        start_over_button = gr.Button("Start Over", variant="secondary", size="lg")

    # Popup component
    with gr.Group(visible=True, elem_classes="popup-overlay") as popup_overlay:
        with gr.Group(elem_classes="popup-content"):
            # Main intro content
            intro_content = gr.Markdown("""
# CODA: Consensus-Driven Active Model Selection
## Wildlife Photo Classification Challenge

You are a wildlife ecologist who has just collected a season's worth of imagery from cameras deployed in Africa and Central and South America. You want to know what species occur in this imagery, and you hope to use a pre-trained classifier to give you answers quickly. But which one should you use?

Instead of labeling a large validation set, our new method, **CODA**, enables you to perform **active model selection**. That is, CODA uses predictions from candidate models to guide the labeling process, querying you (a species identification expert) for labels on a select few images that will most efficiently differentiate between your candidate machine learning models.

This demo lets you try CODA yourself! First, **become a species identification expert by reading our classification guide** so that you will be equipped to provide ground truth labels. Then, watch as CODA narrows down the best model over time as you provide labels for the query images. You will see that with your input CODA is able to identify the best model candidate with as few as ten (correctly) labeled images.
""")

            # Species guide content (initially hidden)
            with gr.Column(visible=False) as species_guide_content:
                create_species_guide_content()

            # Add spacing before buttons
            gr.HTML("")

            with gr.Row():
                back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
                guide_button = gr.Button("View Species Classification Guide", variant="primary", size="lg")
                popup_start_button = gr.Button("Start Demo", variant="secondary", size="lg")

    # Help popups for panels
    with gr.Group(visible=False, elem_classes="help-popup-overlay") as prob_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
## CODA Model Selection Probabilities

This chart shows CODA's current confidence in each candidate classifier being the best performer.

**How to read this chart:**
- Each bar represents one of the candidate machine learning classifiers
- The height of each bar shows the probability (0-100%) that this model is the best, according to CODA
- The orange bar indicates CODA's current best guess
- As you provide more labels, CODA updates these probabilities

**What you'll see:**
- CODA initializes these probabilities based on each classifier's agreement with the consensus votes of *all* classifiers, providing informative priors
- As you label images, some models will gain confidence while others lose it
- The goal is for one model to clearly emerge as the winner

More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!

**What models are these?** For this demo, we selected 5 zero-shot classifiers that would be reasonable choices for someone who wanted to classify wildlife imagery. The models are: facebook/PE-Core-L14-336, google/siglip2-so400m-patch16-naflex, openai/clip-vit-large-patch14, imageomics/bioclip-2, and laion/CLIP-ViT-L-14-laion2B-s32B-b82K. Our goal is not to make any general claims about the performance of these models but rather to provide a realistic set of candidates for demonstrating CODA.
""")
            gr.HTML("")
            prob_help_close = gr.Button("Close", variant="secondary")

    with gr.Group(visible=False, elem_classes="help-popup-overlay") as acc_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
## True Model Accuracies

This chart shows the actual performance of each model on the complete dataset (only possible with oracle knowledge).

**How to read this chart:**
- Each bar represents the true accuracy of one model
- The red bar shows the actual best-performing model
- This information is hidden from CODA during the selection process
- You can compare this with CODA's estimates to see how well it's doing

**Why this matters:**
- This represents the "ground truth" that CODA is trying to discover
- In real scenarios, you wouldn't know these true accuracies beforehand
- The demo shows these to illustrate how CODA's estimates align with reality
""")
            acc_help_close = gr.Button("Close", variant="secondary")

    with gr.Group(visible=False, elem_classes="help-popup-overlay") as selection_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
## How CODA selects images for labeling

CODA selects images that best differentiate top-performing classifiers from each other. It does this by constructing a probabilistic model of which classifier is best (see the plot at the bottom-left). Each iteration, CODA selects an image to be labeled based on how much a label for that image is expected to affect the probabilistic model. Intuitively, CODA will select images where the top classifiers disagree, since knowing the ground truth for these images will provide the most information about which classifier is best overall.

More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!

**What data is this?** We selected a subset of 5 species from the iWildCam dataset, and subsampled a dataset of ~500 images for this demo. Each refresh will generate a slightly different subset, leading to slightly different model selection performance.
""")
            gr.HTML("")
            selection_help_close = gr.Button("Close", variant="secondary")

    # Species guide popup during demo
    with gr.Group(visible=False, elem_classes="popup-overlay") as species_guide_popup:
        with gr.Group(elem_classes="popup-content"):
            create_species_guide_content()
            # Add spacing before button
            gr.HTML("")
            species_guide_close = gr.Button("Go back to demo", variant="primary", size="lg")

    # Status display with help button and result on same row
    selection_help_button = gr.Button("", visible=False, elem_id="hidden-selection-help-btn")
    with gr.Row():
        with gr.Column(scale=3):
            status_with_help = gr.HTML("", visible=True, elem_classes="status-text")
        with gr.Column(scale=2):
            result_display = gr.Markdown("", visible=True, elem_classes="text-right")

    with gr.Row():
        image_display = gr.Image(
            label="Identify this animal:",
            value=None,
            height=400,
            width=550,
            elem_id="main-image-display"
        )

    gr.Markdown("### What species is this?")

    with gr.Row():
        # Create buttons for each species
        species_buttons = []
        for species_name in SPECIES_MAP.values():
            btn = gr.Button(species_name, variant="primary", size="lg")
            species_buttons.append(btn)
        # Add "I don't know" button
        idk_button = gr.Button("I don't know", variant="primary", size="lg")

    # Model predictions panel (full width, single line)
    with gr.Group(elem_classes="model-predictions-panel"):
        with gr.Column(elem_classes="flex items-center justify-center h-full"):
            model_predictions_display = gr.Markdown(
                "### Model Predictions\n\n*Start the demo to see model votes!*",
                show_label=False,
                elem_classes="text-center compact-predictions"
            )

    # Two panels with bar charts
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-container"):
                prob_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
                prob_plot = gr.Plot(
                    value=None,
                    show_label=False
                )

        with gr.Column(scale=1):
            # with gr.Group(elem_classes="panel-container"):
            #     acc_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
            #     with gr.Row(elem_classes="flex-grow") as accuracy_title_row:
            #         gr.Markdown("""
            #             ## True Model Accuracies
            #
            #             Click below to view the true model accuracies.
            #
            #             Note you wouldn't be able to do this in the real model selection setting!
            #
            #             """,
            #             elem_classes="text-center compact-predictions flex-grow")
            #     # Centered reveal button (initially visible)
            #     with gr.Group(visible=True) as accuracy_hidden_group:
            #         with gr.Column(elem_classes="flex items-center justify-center h-full"):
            #             reveal_accuracy_button = gr.Button(
            #                 "Reveal model accuracies",
            #                 variant="primary"
            #             )
            #     # Accuracy plot (initially hidden)
            #     accuracy_plot = gr.Plot(
            #         value=create_accuracy_chart(),
            #         show_label=False,
            #         visible=False
            #     )
            accuracy_plot = gr.Plot(
                value=create_accuracy_chart(),
                show_label=False,
                visible=False
            )
            with gr.Group(visible=True, elem_classes="subtle-outline accuracy-hidden-panel") as hidden_group:
                with gr.Column(elem_classes="flex items-center justify-center h-full"):
                    # example of how to add spacing:
                    # gr.HTML("")
                    hidden_text0 = gr.Markdown("""
## True model performance is hidden
""", elem_classes="text-center",)
                    gr.HTML("")
                    hidden_text1 = gr.Markdown("""
In this problem setting the true model performance is assumed to be unknown (that is why we want to perform model selection!)
However, for this demo, we have computed the actual accuracies of each model in order to evaluate CODA's performance.
""", elem_classes="text-center",)
                    gr.HTML("")
                    # with gr.Row():
                    #     with gr.Column(scale=2):
                    #         pass
                    #     with gr.Column(scale=1, min_width=100):
                    #         reveal_accuracy_button = gr.Button(
                    #             "🔍 Reveal",
                    #             variant="secondary",
                    #             size="lg"
                    #         )
                    #     with gr.Column(scale=2):
                    #         pass
                    with gr.Row():
                        reveal_accuracy_button = gr.Button(
                            "🔍 Reveal True Model Accuracies",
                            variant="secondary",
                            size="lg"
                        )
                    # example of how to add spacing:
                    # gr.HTML("")

    # Add debug delete button (only visible in debug mode)
    if DEBUG_MODE:
        delete_button = gr.Button("🗑️ Delete Current Image", variant="stop", size="lg")

    # Set up button interactions
    def resample_user_dataset():
        """(Re)draw a class-balanced per-user subset and rebuild the oracle and CODA selector.

        Keeps resampling until the subset is one where CODA's initial best guess is NOT the
        true best model, so the demo always has something to learn.
        """
        global coda_selector, dataset, oracle, image_filenames

        while True:
            # Subsample dataset for this user
            subsampled_indices = []
            for class_idx in sorted(full_class_to_indices.keys()):
                indices = full_class_to_indices[class_idx]
                sampled = np.random.choice(indices, size=min_class_size, replace=False)
                subsampled_indices.extend(sampled.tolist())

            # Sort indices to maintain order
            subsampled_indices.sort()

            # Create subsampled dataset for this user
            subsampled_preds = full_preds[:, subsampled_indices, :]
            subsampled_labels = full_labels[subsampled_indices]
            image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]

            # Create Dataset object with subsampled data
            dataset = Dataset.__new__(Dataset)
            dataset.preds = subsampled_preds
            dataset.labels = subsampled_labels
            dataset.device = device

            # Create oracle and CODA selector for this user
            oracle = Oracle(dataset, loss_fn=loss_fn)
            coda_selector = CODA(dataset, learning_rate=DEMO_LEARNING_RATE, alpha=DEMO_ALPHA)

            # Check which model is initially best according to CODA
            probs_tensor = coda_selector.get_pbest()
            probabilities = probs_tensor.detach().cpu().numpy().flatten()
            coda_best_idx = np.argmax(probabilities)

            # Get true best model according to oracle
            true_losses = oracle.true_losses(dataset.preds)
            true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
            true_best_idx = np.argmax(true_accuracies)

            # Accept this subset if CODA's initial best is NOT the true best
            if coda_best_idx != true_best_idx:
                break
            # Otherwise, loop and resample

    def start_demo():
        global iteration_count
        # Reset the demo state
        iteration_count = 0

        # Draw a fresh per-user subset and rebuild the oracle and CODA selector
        resample_user_dataset()

        image, status, predictions = get_next_coda_image()
        prob_plot = create_probability_chart()
        acc_plot = create_accuracy_chart()

        # Create HTML with inline help button
        status_html = f'{status} ?'
        return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)

    def start_over():
        global iteration_count
        # Reset the demo state
        iteration_count = 0

        # Draw a new random per-user subset and rebuild the oracle and CODA selector
        resample_user_dataset()

        # Reset all displays
        prob_plot = create_probability_chart()
        acc_plot = create_accuracy_chart()

        return None, "Demo reset. Click 'Start Demo' to begin.", "### Model Predictions\n\n*Start the demo to see model votes!*", prob_plot, acc_plot, "", gr.update(visible=True), gr.update(visible=False)

    def show_species_guide():
        # Show species guide, hide intro content, show back button, hide guide button
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)

    def show_intro():
        # Show intro content, hide species guide, hide back button, show guide button
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    def show_prob_help():
        return gr.update(visible=True)

    def hide_prob_help():
        return gr.update(visible=False)

    def show_acc_help():
        return gr.update(visible=True)

    def hide_acc_help():
        return gr.update(visible=False)

    def show_selection_help():
        return gr.update(visible=True)

    def hide_selection_help():
        return gr.update(visible=False)

    def show_species_guide_popup():
        return gr.update(visible=True)

    def hide_species_guide_popup():
        return gr.update(visible=False)

    def reveal_accuracies():
        """Reveal accuracy plot and hide the hidden group"""
        return gr.update(visible=True), gr.update(visible=False)
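
    # Added comment: the JavaScript snippet shared by the Start Demo / Start Over handlers below
    # matches the height of the right-hand "hidden accuracies" panel to the probability-plot
    # panel. It polls every 50 ms for 3 seconds because the plot is rendered asynchronously
    # after the Python callback returns, so the panel heights are not final when the click fires.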
    PANEL_HEIGHT_MATCH_JS = """
    () => {
        console.log('=== Panel Height Matching (Dynamic) ===');

        function matchPanelHeights() {
            const panels = document.querySelectorAll('.panel-container');
            console.log('Found .panel-container elements:', panels.length);
            const leftPanel = panels[0];  // prob_plot panel
            const rightPanel = document.querySelector('.accuracy-hidden-panel');  // hidden_group panel
            console.log('Left panel (prob):', leftPanel);
            console.log('Right panel (hidden):', rightPanel);

            if (leftPanel && rightPanel) {
                const leftHeight = leftPanel.offsetHeight;
                const rightHeight = rightPanel.offsetHeight;
                const diff = leftHeight - rightHeight;
                console.log('Left panel height:', leftHeight);
                console.log('Right panel height:', rightHeight);
                console.log('Height difference:', diff);

                if (diff > 0) {
                    console.log('Setting right panel min-height to:', leftHeight + 'px');
                    rightPanel.style.minHeight = leftHeight + 'px';
                    rightPanel.style.display = 'flex';
                    rightPanel.style.flexDirection = 'column';
                    rightPanel.style.justifyContent = 'center';
                    console.log('Applied min-height and flex centering');
                    return true;  // Success
                } else {
                    console.log('No height adjustment needed (diff <= 0)');
                    return true;  // Success
                }
            } else {
                console.log('Panels not ready yet');
                return false;  // Not ready
            }
        }

        // Check every 50ms for 3 seconds to catch multiple height changes
        let attempts = 0;
        const maxAttempts = 60;  // 60 * 50ms = 3 seconds to catch both height changes
        const checkInterval = setInterval(() => {
            attempts++;
            console.log('Attempt', attempts, 'to match heights');
            matchPanelHeights();  // Always try, don't stop early
            if (attempts >= maxAttempts) {
                console.log('Finished checking after 3 seconds');
                clearInterval(checkInterval);
            }
        }, 50);  // Check every 50ms
    }
    """

    popup_start_button.click(
        fn=start_demo,
        outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, popup_overlay, result_display, selection_help_button],
        js=PANEL_HEIGHT_MATCH_JS
    )

    start_over_button.click(
        fn=start_over,
        outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, result_display, popup_overlay, selection_help_button],
        js=PANEL_HEIGHT_MATCH_JS
    )

    guide_button.click(
        fn=show_species_guide,
        outputs=[intro_content, species_guide_content, back_button, guide_button]
    )
    back_button.click(
        fn=show_intro,
        outputs=[intro_content, species_guide_content, back_button, guide_button]
    )

    # Help popup handlers
    prob_help_button.click(
        fn=show_prob_help,
        outputs=[prob_help_popup]
    )
    prob_help_close.click(
        fn=hide_prob_help,
        outputs=[prob_help_popup]
    )
    # acc_help_button.click(
    #     fn=show_acc_help,
    #     outputs=[acc_help_popup]
    # )
    acc_help_close.click(
        fn=hide_acc_help,
        outputs=[acc_help_popup]
    )
    selection_help_button.click(
        fn=show_selection_help,
        outputs=[selection_help_popup]
    )
    selection_help_close.click(
        fn=hide_selection_help,
        outputs=[selection_help_popup]
    )

    # Reveal accuracy button handler
    reveal_accuracy_button.click(
        fn=reveal_accuracies,
        outputs=[accuracy_plot, hidden_group]
    )

    # Species guide popup handlers
    view_guide_button.click(
        fn=show_species_guide_popup,
        outputs=[species_guide_popup]
    )
    species_guide_close.click(
        fn=hide_species_guide_popup,
        outputs=[species_guide_popup]
    )

    for btn in species_buttons:
        btn.click(
            fn=check_answer,
            inputs=[gr.State(btn.value)],
            outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
        )

    idk_button.click(
        fn=check_answer,
        inputs=[gr.State("I don't know")],
        outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
    )

    # Wire up delete button in debug mode
    if DEBUG_MODE:
        delete_button.click(
            fn=delete_current_image,
            outputs=[result_display, image_display, model_predictions_display, prob_plot, accuracy_plot, status_with_help]
        )
    # Add JavaScript to handle inline help button clicks and dynamic image sizing
    demo.load(
        lambda: None,
        outputs=[],
        js="""
        () => {
            // Handle inline help button clicks
            setTimeout(() => {
                document.addEventListener('click', function(e) {
                    if (e.target && e.target.classList.contains('inline-help-btn')) {
                        e.preventDefault();
                        e.stopPropagation();
                        const hiddenBtn = document.getElementById('hidden-selection-help-btn');
                        if (hiddenBtn) {
                            hiddenBtn.click();
                        }
                    }
                });
            }, 100);

            // Dynamic image sizing (NEW VERSION)
            console.log('=== IMAGE SIZING V2 LOADED ===');

            function adjustImageSize() {
                const imageContainer = document.getElementById('main-image-display');
                if (!imageContainer) {
                    console.log('[V2] Image container not found');
                    return false;
                }

                const viewportHeight = window.innerHeight;
                const docHeight = document.documentElement.scrollHeight;
                const currentImageHeight = imageContainer.offsetHeight;

                // Calculate how much we're overflowing
                const overflow = docHeight - viewportHeight;

                // If we're not overflowing, increase image size
                // If we are overflowing, decrease image size by the overflow amount
                const adjustment = -overflow - 30;  // Keep padding below bottom button
                const targetHeight = currentImageHeight + adjustment;

                console.log('[V2] viewport:', viewportHeight, 'docHeight:', docHeight, 'currentImg:', currentImageHeight, 'overflow:', overflow, 'target:', targetHeight);

                // Only apply if reasonable
                if (targetHeight > 300 && targetHeight < viewportHeight - 100) {
                    imageContainer.style.height = targetHeight + 'px';
                    imageContainer.style.maxHeight = targetHeight + 'px';
                    console.log('[V2] Set image height to:', targetHeight + 'px');
                    return true;
                }
                return false;
            }

            // Run after initial load
            setTimeout(adjustImageSize, 500);

            // Run periodically for first 5 seconds to catch layout changes
            let attempts = 0;
            const interval = setInterval(() => {
                attempts++;
                adjustImageSize();
                if (attempts >= 50) {  // 50 * 100ms = 5 seconds
                    clearInterval(interval);
                }
            }, 100);

            // Re-adjust on window resize
            window.addEventListener('resize', adjustImageSize);
        }
        """,
    )

if __name__ == "__main__":
    demo.launch(
        # share=True,
        # server_port=7861,
        allowed_paths=["/"],
    )
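
# Usage note (added; assumes this script is saved as app.py): launch the demo with
#   python app.py
# or, to also show the debug-only "Delete Current Image" button,
#   python app.py --debug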