# coda / app.py
import os
os.environ['GRADIO_TEMP_DIR'] = "tmp/"
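# Point Gradio's temp/cache directory at tmp/ (set before importing gradio so the setting is picked up)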
import gradio as gr
import json
import random
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
import torch
import shutil
import argparse
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from coda import CODA
from coda.datasets import Dataset
from coda.options import LOSS_FNS
from coda.oracle import Oracle
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='Enable debug mode with delete button')
args_cli = parser.parse_args()
DEBUG_MODE = args_cli.debug
if DEBUG_MODE:
print("Debug mode enabled - delete button will be available")
# Create deleted_in_app directory if it doesn't exist
os.makedirs('deleted_in_app', exist_ok=True)
with open('iwildcam_demo_annotations.json', 'r') as f:
data = json.load(f)
SPECIES_MAP = OrderedDict([
(24, "Jaguar"), # panthera onca
(10, "Ocelot"), # leopardus pardalis
(6, "Mountain Lion"), # puma concolor
(101, "Common Eland"), # tragelaphus oryx
(102, "Waterbuck"), # kobus ellipsiprymnus
])
NAME_TO_ID = {name: id for id, name in SPECIES_MAP.items()}
# Class names in order (0-4) from classes.txt
CLASS_NAMES = ["Jaguar", "Ocelot", "Mountain Lion", "Common Eland", "Waterbuck"]
NAME_TO_CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}
# Model information from models.txt
MODEL_INFO = [
{"org": "Facebook", "name": "PE-Core", "logo": "logos/meta.png"},
{"org": "Google", "name": "SigLIP2", "logo": "logos/google.png"},
{"org": "OpenAI", "name": "CLIPViT-L", "logo": "logos/openai.png"},
{"org": "Imageomics", "name": "BioCLIP2", "logo": "logos/imageomics.png"},
{"org": "LAION", "name": "LAION CLIP", "logo": "logos/laion.png"}
]
DEMO_LEARNING_RATE = 0.05  # deliberately not the default; makes the demo more dynamic
DEMO_ALPHA = 0.9  # 0.25 is more interesting when showing the confusion matrices
# Toggle between confusion matrix and accuracy chart
USE_CONFUSION_MATRIX = False # Set to True for confusion matrices, False for accuracy bars
def create_species_guide_content():
"""Create the species identification guide content"""
with gr.Column():
gr.Markdown("""
# Species Classification Guide
### Learn to identify the five wildlife species in this demo.
## Jaguar
""")
gr.Image("species_id/jaguar.jpg", label="Jaguar example image", show_label=False)
gr.Markdown("""
#### The largest cat in the Americas, with a stocky, muscular build and a broad head. Coat is patterned with rosettes that often have central spots inside.
----
## Ocelot
""")
gr.Image("species_id/ocelot.jpg", label="Ocelot example image", show_label=False)
gr.Markdown("""
#### Smaller and leaner than a jaguar, with more elongated markings and rounder ears.
----
## Mountain Lion
""")
gr.Image("species_id/mountainlion.jpg", label="Mountain lion example image", show_label=False)
gr.Markdown("""
#### Also called cougar or puma, this cat has a plain tawny or grayish coat without spots. Its long tail and uniformly colored fur distinguish it from jaguars and ocelots.
----
## Common Eland
""")
gr.Image("species_id/commoneland.jpg", label="Eland example image", show_label=False)
gr.Markdown("""
#### The largest antelope species. Identifiable by the spiraled horns present on both sexes. Lighter tan coat than a waterbuck.
----
## Waterbuck
""")
gr.Image("species_id/waterbuck.jpg", label="Waterbuck example image", show_label=False)
gr.Markdown("""
#### A shaggy, dark brown antelope. Identifiable by backward-curving horns in males, no horns on females. Larger, rounder ears and darker coat than the common eland.
----
""")
# load image metadata
images_data = []
for annotation in tqdm(data['annotations'], desc='Loading annotations'):
image_id = annotation['image_id']
category_id = annotation['category_id']
image_info = next((img for img in data['images'] if img['id'] == image_id), None)
if image_info:
images_data.append({
'filename': image_info['file_name'],
'species_id': category_id,
'species_name': SPECIES_MAP[category_id]
})
print(f"Loaded {len(images_data)} images for the quiz")
# Load image filenames list
with open('images.txt', 'r') as f:
full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]
# Initialize full dataset (will be subsampled per-user)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load full dataset
full_preds = torch.load("iwildcam_demo.pt").to(device)
full_labels = torch.load("iwildcam_demo_labels.pt").to(device)
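# Expected shapes, inferred from how these tensors are indexed below:
#   full_preds:  [num_models, num_images, num_classes] per-model class scores
#   full_labels: [num_images] integer class indices (0-4)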
# Pre-compute class indices for subsampling
from collections import defaultdict
full_class_to_indices = defaultdict(list)
for idx, label in enumerate(full_labels):
class_idx = label.item()
full_class_to_indices[class_idx].append(idx)
# Find minimum class size
min_class_size = min(len(indices) for indices in full_class_to_indices.values())
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")
# Loss function for oracle
loss_fn = LOSS_FNS['acc']
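# 'acc' is assumed to be a 0/1 error-rate loss here, so (1 - loss) is displayed as accuracy below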
# Global state (will be set per-user in start_demo)
current_image_info = None
coda_selector = None
oracle = None
dataset = None
image_filenames = None
iteration_count = 0
def get_model_predictions(chosen_idx):
"""Get model predictions and scores for a specific image"""
global dataset
if dataset is None or chosen_idx >= dataset.preds.shape[1]:
return "No predictions available"
# Get predictions for this image (shape: [num_models, num_classes])
image_preds = dataset.preds[:, chosen_idx, :].detach().cpu().numpy()
predictions_list = []
for model_idx in range(image_preds.shape[0]):
model_scores = image_preds[model_idx]
predicted_class_idx = model_scores.argmax()
predicted_class_name = CLASS_NAMES[predicted_class_idx]
confidence = model_scores[predicted_class_idx]
model_info = MODEL_INFO[model_idx]
predictions_list.append(f"**{model_info['name']}:** {predicted_class_name} *({confidence:.3f})*")
predictions_text = "### Model Predictions\n\n" + " | ".join(predictions_list)
return predictions_text
def add_logo_to_x_axis(ax, x_pos, logo_path, model_name, height_px=35):
"""Add a logo image to x-axis next to model name"""
try:
img = mpimg.imread(logo_path)
# Calculate zoom factor to achieve roughly the desired height in pixels
# Rough conversion: height_px / smaller image dimension / dpi * 72
zoom = height_px / min(img.shape[0], img.shape[1]) / ax.figure.dpi * 72
imagebox = OffsetImage(img, zoom=zoom)
# Position logo to the left of the x-tick
logo_offset = -0.28 # Adjust this to move logo left/right relative to tick
y_offset = -0.08
ab = AnnotationBbox(imagebox, (x_pos + logo_offset, y_offset),
xycoords=('data', 'axes fraction'), frameon=False)
ax.add_artist(ab)
except Exception as e:
print(f"Could not load logo {logo_path}: {e}")
def get_next_coda_image():
"""Get the next image that CODA wants labeled"""
global current_image_info, coda_selector, iteration_count
# Get next item from CODA
chosen_idx, selection_prob = coda_selector.get_next_item_to_label()
print("CODA chosen_idx, selection prob:", chosen_idx, selection_prob)
# Get the corresponding image filename
if chosen_idx < len(image_filenames):
filename = image_filenames[chosen_idx]
image_path = os.path.join('iwildcam_demo_images', filename)
print("Next image is", filename)
# Find the corresponding annotation for this image
current_image_info = None
for annotation in data['annotations']:
image_id = annotation['image_id']
image_info = next((img for img in data['images'] if img['id'] == image_id), None)
if image_info and image_info['file_name'] == filename:
current_image_info = {
'filename': filename,
'species_id': annotation['category_id'],
'species_name': SPECIES_MAP[annotation['category_id']],
'chosen_idx': chosen_idx,
'selection_prob': selection_prob
}
break
try:
image = Image.open(image_path)
predictions = get_model_predictions(chosen_idx)
return image, f"Iteration {iteration_count}: CODA selected this image for labeling", predictions
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None, f"Error loading image: {e}", "No predictions available"
else:
return None, "Image index out of range", "No predictions available"
def delete_current_image():
"""Delete the current image by moving it to deleted_in_app directory"""
global current_image_info, coda_selector
if current_image_info is None:
return "No image to delete!", None, "No predictions", None, None, ""
filename = current_image_info['filename']
chosen_idx = current_image_info['chosen_idx']
source_path = os.path.join('iwildcam_demo_images', filename)
dest_path = os.path.join('deleted_in_app', filename)
try:
shutil.move(source_path, dest_path)
result = f"✓ Moved {filename} to deleted_in_app/"
print(f"Deleted image: {filename}")
# Remove from CODA's unlabeled indices without adding a label
if chosen_idx in coda_selector.unlabeled_idxs:
coda_selector.unlabeled_idxs.remove(chosen_idx)
except Exception as e:
result = f"Error deleting image: {e}"
print(f"Error deleting {filename}: {e}")
# Load next image
next_image, status, predictions = get_next_coda_image()
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
# Get updated plots
prob_plot = create_probability_chart()
accuracy_plot = create_accuracy_chart()
return result, next_image, predictions, prob_plot, accuracy_plot, status_html
def check_answer(user_choice):
"""Process user's label and update CODA"""
global current_image_info, coda_selector, iteration_count
if current_image_info is None:
return "Please load an image first!", "", None, "No predictions", None, None
correct_species = current_image_info['species_name']
chosen_idx = current_image_info['chosen_idx']
selection_prob = current_image_info['selection_prob']
# Convert user choice to class index (0-4)
if user_choice == "I don't know":
# For "I don't know", just remove from sampling without providing label
coda_selector.unlabeled_idxs.remove(chosen_idx)
result = f"The last image was skipped and will not be used for model selection. The correct species was {correct_species}. "
else:
user_class_idx = NAME_TO_CLASS_IDX.get(user_choice, NAME_TO_CLASS_IDX[correct_species])
if user_choice == correct_species:
result = f"🎉 Your last classification was correct! It was indeed a {correct_species}."
else:
result = f"❌ Your last classification was incorrect. It was a {correct_species}, not a {user_choice}. This may mislead the model selection process!"
# Update CODA with the label
coda_selector.add_label(chosen_idx, user_class_idx, selection_prob)
iteration_count += 1
# Get updated plots
prob_plot = create_probability_chart()
accuracy_plot = create_accuracy_chart()
# Load next image
next_image, status, predictions = get_next_coda_image()
# Create HTML with inline help button for status
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
return result, status_html, next_image, predictions, prob_plot, accuracy_plot
def create_probability_chart():
"""Create a bar chart showing probability each model is best"""
global coda_selector
if coda_selector is None:
# Fallback for initial state
model_labels = [info['name'] for info in MODEL_INFO]
probabilities = np.ones(len(MODEL_INFO)) / len(MODEL_INFO) # Uniform prior
else:
probs_tensor = coda_selector.get_pbest()
probabilities = probs_tensor.detach().cpu().numpy().flatten()
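# Pad the label text so the logos drawn to the left of each x-tick don't overlap the model names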
model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(probabilities)]]
# Find the index of the highest probability
best_idx = np.argmax(probabilities)
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
# Create colors array - highlight the best model
colors = ['orange' if i == best_idx else 'steelblue' for i in range(len(model_labels))]
bars = ax.bar(range(len(model_labels)), probabilities, color=colors, alpha=0.7)
# Add text above the highest bar
ax.text(best_idx, probabilities[best_idx] + 0.0025, 'Current best guess',
ha='center', va='bottom', fontsize=12, fontweight='bold')
ax.set_ylabel('Probability model is best', fontsize=12)
ax.set_title(f'CODA Model Selection Probabilities (Iteration {iteration_count})', fontsize=12)
ax.set_ylim(np.min(probabilities) - 0.01, np.max(probabilities) + 0.02)
# Set x-axis labels and ticks
ax.set_xticks(range(len(model_labels)))
ax.set_xticklabels(model_labels, fontsize=12, ha='center')
# Add logos to x-axis
for i, model_info in enumerate(MODEL_INFO[:len(probabilities)]):
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
plt.yticks(fontsize=12)
plt.tight_layout()
# Save the figure and close it to prevent memory leaks
temp_fig = fig
plt.close(fig)
return temp_fig
def create_accuracy_chart():
"""Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle"""
global coda_selector, oracle, dataset, iteration_count
if USE_CONFUSION_MATRIX:
return create_confusion_matrix_chart()
else:
return create_accuracy_bar_chart()
def create_confusion_matrix_chart():
"""Create confusion matrix estimates for each model side by side"""
global coda_selector, iteration_count
if coda_selector is None:
# Fallback for initial state - return empty figure
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
ha='center', va='center', fontsize=12)
ax.axis('off')
plt.tight_layout()
temp_fig = fig
plt.close(fig)
return temp_fig
# Get confusion matrix estimates from CODA's Dirichlet distributions
dirichlets = coda_selector.dirichlets # Shape: [num_models, num_classes, num_classes]
num_models = dirichlets.shape[0]
num_classes = dirichlets.shape[1]
# Convert Dirichlet parameters to expected confusion matrices
# The expected value of a Dirichlet is alpha / sum(alpha)
confusion_matrices = []
for model_idx in range(num_models):
alpha = dirichlets[model_idx].detach().cpu().numpy()
# Normalize each row to get probabilities
conf_matrix = alpha / alpha.sum(axis=1, keepdims=True)
confusion_matrices.append(conf_matrix)
# Create subplots for each model
# Adjust width based on number of models (2.4 inches per model works well)
fig_width = num_models * 2.4
fig, axes = plt.subplots(1, num_models, figsize=(fig_width, 2.8), dpi=150)
if num_models == 1:
axes = [axes]
# Species abbreviations for axis labels
species_labels = ['Jag', 'Oce', 'M.L.', 'C.E.', 'Wat']
for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
# Apply repeated square-root (16th-root) scaling to make small values more visible
# while preserving the ordering of larger values
sqrt_conf_matrix = conf_matrix ** (1 / 16)
# Plot confusion matrix as heatmap with sqrt-scaled values
im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto')#, vmin=0, vmax=1)
# Add model name as title
model_info = MODEL_INFO[model_idx]
ax.set_title(f"{model_info['name']}", fontsize=10, pad=5)
# Set axis labels
if model_idx == 0:
ax.set_ylabel('True class', fontsize=9)
ax.set_xlabel('Predicted', fontsize=9)
# Set ticks with species abbreviations
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
ax.set_xticklabels(species_labels[:num_classes], fontsize=8)
ax.set_yticklabels(species_labels[:num_classes], fontsize=8)
plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98)
plt.tight_layout()
temp_fig = fig
plt.close(fig)
return temp_fig
def create_accuracy_bar_chart():
"""Create a bar chart showing true accuracy of each model (with muted colors)"""
global oracle, dataset
if oracle is None or dataset is None:
# Fallback for initial state
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
ax.text(0.5, 0.5, 'Start demo to see model accuracies',
ha='center', va='center', fontsize=12)
ax.axis('off')
plt.tight_layout()
temp_fig = fig
plt.close(fig)
return temp_fig
true_losses = oracle.true_losses(dataset.preds)
# Convert losses to accuracies (assuming loss is 1 - accuracy)
accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
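# Pad the label text so the logos drawn to the left of each x-tick don't overlap the model names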
model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(accuracies)]]
# Find the index of the highest accuracy
best_idx = np.argmax(accuracies)
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
# Create colors array - highlight the best model with dark reddish orange, others soft pink
colors = ['#F8481C' if i == best_idx else '#F8BBD0' for i in range(len(model_labels))]
bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.85)
# Add text above the highest bar
ax.text(best_idx, accuracies[best_idx] + 0.0025, 'True best model',
ha='center', va='bottom', fontsize=12, fontweight='bold')
ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
ax.set_title('True Model Accuracies', fontsize=12)
y_min = np.min(accuracies) - 0.025
y_max = np.max(accuracies) + 0.05
ax.set_ylim(y_min, y_max)
# Add accuracy values in the middle of the visible portion of each bar
for i, (bar, acc) in enumerate(zip(bars, accuracies)):
# Position text in the middle of the visible part of the bar
text_y = (y_min + acc) / 2
# Use black text for all bars
text_color = '#000000'
ax.text(i, text_y, f'{acc:.3f}',
ha='center', va='center', fontsize=10, fontweight='bold', color=text_color)
# Set x-axis labels and ticks
ax.set_xticks(range(len(model_labels)))
ax.set_xticklabels(model_labels, fontsize=12, ha='center')
# Add logos to x-axis
for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
plt.yticks(fontsize=12)
plt.tight_layout()
# Save the figure and close it to prevent memory leaks
temp_fig = fig
plt.close(fig)
return temp_fig
# Create the Gradio interface
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
theme=gr.themes.Base(),
css="""
.subtle-outline {
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
border-radius: var(--radius-lg);
padding: 1rem;
}
.subtle-outline .flex {
background-color: var(--background-fill-secondary) !important;
}
/* Light blue background for model predictions panel */
.model-predictions-panel {
border: 1px solid #6B8CBF !important;
background: #D6E4F5 !important;
border-radius: var(--radius-lg);
padding: 0.3rem !important;
margin: 0.2rem 0 !important;
}
.model-predictions-panel .flex {
background-color: #D6E4F5 !important;
padding: 0 !important;
margin: 0 !important;
}
.model-predictions-panel * {
color: #1a1a1a !important;
}
/* Popup overlay styles */
.popup-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.5);
z-index: 1000;
display: flex;
justify-content: center;
align-items: center;
}
.popup-overlay > div {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.popup-content {
background: var(--background-fill-primary) !important;
padding: 2rem !important;
border-radius: 1rem !important;
max-width: 850px;
width: 90%;
max-height: 80vh;
overflow-y: auto;
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
border: none !important;
margin: 0 !important;
color: var(--body-text-color) !important;
}
.popup-content > div {
background: var(--background-fill-primary) !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
overflow-y: visible !important;
max-height: none !important;
}
.popup-content h1,
.popup-content h2,
.popup-content h3,
.popup-content p,
.popup-content li {
color: var(--body-text-color) !important;
}
/* Ensure gradio column components don't interfere with scrolling */
.popup-content .gradio-column {
overflow-y: visible !important;
max-height: none !important;
}
/* Ensure images in popup are responsive */
.popup-content img {
max-width: 100% !important;
height: auto !important;
}
/* Center title */
.text-center {
text-align: center !important;
}
/* Right align text */
.text-right {
text-align: right !important;
}
/* Subtitle styling */
.subtitle {
text-align: center !important;
font-weight: 300 !important;
color: #666 !important;
margin-top: -0.5rem !important;
}
/* Question mark icon styling */
.panel-container {
position: relative;
}
.help-icon {
position: absolute;
top: 5px;
right: 5px;
width: 25px;
height: 25px;
background-color: #f8f9fa;
color: #6c757d;
border: 1px solid #dee2e6;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
font-size: 13px;
font-weight: 600;
z-index: 10;
transition: all 0.2s ease;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.help-icon:hover {
background-color: #e9ecef;
color: #495057;
border-color: #adb5bd;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
}
/* Help popup styles */
.help-popup-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.5);
z-index: 1001;
display: flex;
justify-content: center;
align-items: center;
}
.help-popup-overlay > div {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.help-popup-content {
background: var(--background-fill-primary) !important;
padding: 1.5rem !important;
border-radius: 0.5rem !important;
max-width: 600px;
width: 90%;
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
border: none !important;
margin: 0 !important;
color: var(--body-text-color) !important;
}
.help-popup-content > div {
background: var(--background-fill-primary) !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.help-popup-content h1,
.help-popup-content h2,
.help-popup-content h3,
.help-popup-content p,
.help-popup-content li {
color: var(--body-text-color) !important;
}
/* Inline help button */
.inline-help-btn {
display: inline-block;
width: 20px;
height: 20px;
background-color: #f8f9fa;
color: #6c757d;
border: 1px solid #dee2e6;
border-radius: 50%;
text-align: center;
line-height: 18px;
cursor: pointer;
font-size: 11px;
font-weight: 600;
margin-left: 8px;
vertical-align: middle;
transition: all 0.2s ease;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.inline-help-btn:hover {
background-color: #e9ecef;
color: #495057;
border-color: #adb5bd;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
}
#hidden-selection-help-btn {
display: none;
}
/* Reduce spacing around status text */
.status-text {
margin: 0 !important;
padding: 0 !important;
}
.status-text > div {
margin: 0 !important;
padding: 0 !important;
}
/* Compact model predictions panel */
.compact-predictions {
line-height: 1.1 !important;
margin: 0 !important;
padding: 0.1rem !important;
}
.compact-predictions p {
margin: 0.05rem 0 !important;
}
.compact-predictions h3 {
margin: 0 0 0.1rem 0 !important;
}
/* Target the subtle-outline group that contains predictions */
.subtle-outline {
padding: 0.3rem !important;
margin: 0.2rem 0 !important;
}
/* Target the column inside the outline */
.subtle-outline .flex {
padding: 0 !important;
margin: 0 !important;
}
/* Ensure text in predictions panel is visible in dark mode */
.subtle-outline * {
color: var(--body-text-color) !important;
}
""",
delete_cache=(3600, 3600) # every hour, delete cached temp files older than an hour
) as demo:
# Main page title
gr.Markdown("# CODA: Consensus-Driven Active Model Selection", elem_classes="text-center")
gr.Markdown("*Figure out which model is best by actively annotating data. See <a href='https://www.arxiv.org/abs/2507.23771'>the paper</a> for more details!*", elem_classes="text-center")
# Add buttons row
with gr.Row():
view_guide_button = gr.Button("📖 View Species Guide", variant="secondary", size="lg")
start_over_button = gr.Button("Start Over", variant="secondary", size="lg")
# Popup component
with gr.Group(visible=True, elem_classes="popup-overlay") as popup_overlay:
with gr.Group(elem_classes="popup-content"):
# Main intro content
intro_content = gr.Markdown("""
# CODA: Consensus-Driven Active Model Selection
## Wildlife Photo Classification Challenge
You are a wildlife ecologist who has just collected a season's worth of imagery from cameras
deployed in Africa and Central and South America. You want to know what species occur in this imagery,
and you hope to use a pre-trained classifier to give you answers quickly.
But which one should you use?
Instead of labeling a large validation set, our new method, **CODA**, enables you to perform **active model selection**.
That is, CODA uses predictions from candidate models to guide the labeling process, querying you (a species identification expert)
for labels on a select few images that will most efficiently differentiate between your candidate machine learning models.
This demo lets you try CODA yourself! First, **become a species identification expert by reading our classification guide**
so that you will be equipped to provide ground truth labels. Then, watch as CODA narrows down the best model over time
as you provide labels for the query images. You will see that, with your input, CODA is able to identify the best candidate model
with as few as ten (correctly) labeled images.
""")
# Species guide content (initially hidden)
with gr.Column(visible=False) as species_guide_content:
create_species_guide_content()
# Add spacing before buttons
gr.HTML("<div style='margin-top: 0.1em;'></div>")
with gr.Row():
back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
guide_button = gr.Button("View Species Classification Guide", variant="primary", size="lg")
popup_start_button = gr.Button("Start Demo", variant="secondary", size="lg")
# Help popups for panels
with gr.Group(visible=False, elem_classes="help-popup-overlay") as prob_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## CODA Model Selection Probabilities
This chart shows CODA's current confidence in each candidate classifier being the best performer.
**How to read this chart:**
- Each bar represents one of the candidate machine learning classifiers
- The height of each bar shows the probability (0-100%) that this model is the best, according to CODA
- The orange bar indicates CODA's current best guess
- As you provide more labels, CODA updates these probabilities
**What you'll see:**
- CODA initializes these probabilities based on each classifier's agreement with the consensus votes of *all* classifiers,
providing informative priors
- As you label images, some models will gain confidence while others lose it
- The goal is for one model to clearly emerge as the winner
More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!
**What models are these?**
For this demo, we selected 5 zero-shot classifiers that would be reasonable choices for someone who
wanted to classify wildlife imagery. The models are: facebook/PE-Core-L14-336,
google/siglip2-so400m-patch16-naflex, openai/clip-vit-large-patch14, imageomics/bioclip-2, and
laion/CLIP-ViT-L-14-laion2B-s32B-b82K. Our goal is not to make any general claims about the performance
of these models but rather to provide a realistic set of candidates for demonstrating CODA.
""")
gr.HTML("<div style='margin-top: 0.1em;'></div>")
prob_help_close = gr.Button("Close", variant="secondary")
with gr.Group(visible=False, elem_classes="help-popup-overlay") as acc_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## True Model Accuracies
This chart shows the actual performance of each model on the complete dataset (only possible with oracle knowledge).
**How to read this chart:**
- Each bar represents the true accuracy of one model
- The red bar shows the actual best-performing model
- This information is hidden from CODA during the selection process
- You can compare this with CODA's estimates to see how well it's doing
**Why this matters:**
- This represents the "ground truth" that CODA is trying to discover
- In real scenarios, you wouldn't know these true accuracies beforehand
- The demo shows these to illustrate how CODA's estimates align with reality
""")
acc_help_close = gr.Button("Close", variant="secondary")
with gr.Group(visible=False, elem_classes="help-popup-overlay") as selection_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## How CODA selects images for labeling
CODA selects images that best differentiate top-performing classifiers from each other. It
does this by constructing a probabilistic model of which classifier is best (see
the plot at the bottom-left). Each iteration, CODA selects an image to be labeled
based on how much a label for that image is expected to affect the probabilistic model.
Intuitively, CODA will select images where the top classifiers disagree, since knowing the ground truth for these images will provide
the most information about which classifier is best overall.
More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!
**What data is this?**
We selected a subset of 5 species from the iWildCam dataset and subsampled ~500 images for this demo.
Each refresh will generate a slightly different subset, leading to slightly different model selection
performance.
""")
gr.HTML("<div style='margin-top: 0.1em;'></div>")
selection_help_close = gr.Button("Close", variant="secondary")
# Species guide popup during demo
with gr.Group(visible=False, elem_classes="popup-overlay") as species_guide_popup:
with gr.Group(elem_classes="popup-content"):
create_species_guide_content()
# Add spacing before button
gr.HTML("<div style='margin-top: 0.1em;'></div>")
species_guide_close = gr.Button("Go back to demo", variant="primary", size="lg")
# Status display with help button and result on same row
selection_help_button = gr.Button("", visible=False, elem_id="hidden-selection-help-btn")
with gr.Row():
with gr.Column(scale=3):
status_with_help = gr.HTML("", visible=True, elem_classes="status-text")
with gr.Column(scale=2):
result_display = gr.Markdown("", visible=True, elem_classes="text-right")
with gr.Row():
image_display = gr.Image(
label="Identify this animal:",
value=None,
height=400,
width=550,
elem_id="main-image-display"
)
gr.Markdown("### What species is this?")
with gr.Row():
# Create buttons for each species
species_buttons = []
for species_name in SPECIES_MAP.values():
btn = gr.Button(species_name, variant="primary", size="lg")
species_buttons.append(btn)
# Add "I don't know" button
idk_button = gr.Button("I don't know", variant="primary", size="lg")
# Model predictions panel (full width, single line)
with gr.Group(elem_classes="model-predictions-panel"):
with gr.Column(elem_classes="flex items-center justify-center h-full"):
model_predictions_display = gr.Markdown(
"### Model Predictions\n\n*Start the demo to see model votes!*",
show_label=False,
elem_classes="text-center compact-predictions"
)
# Two panels with bar charts
with gr.Row():
with gr.Column(scale=1):
with gr.Group(elem_classes="panel-container"):
prob_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
prob_plot = gr.Plot(
value=None,
show_label=False
)
with gr.Column(scale=1):
# with gr.Group(elem_classes="panel-container"):
# acc_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
# with gr.Row(elem_classes="flex-grow") as accuracy_title_row:
# gr.Markdown("""
# ## True Model Accuracies
#
# Click below to view the true model accuracies.
#
# Note you wouldn't be able to do this in the real model selection setting!
#
# """,
# elem_classes="text-center compact-predictions flex-grow")
# # Centered reveal button (initially visible)
# with gr.Group(visible=True) as accuracy_hidden_group:
# with gr.Column(elem_classes="flex items-center justify-center h-full"):
# reveal_accuracy_button = gr.Button(
# "Reveal model accuracies",
# variant="primary"
# )
# # Accuracy plot (initially hidden)
# accuracy_plot = gr.Plot(
# value=create_accuracy_chart(),
# show_label=False,
# visible=False
# )
accuracy_plot = gr.Plot(
value=create_accuracy_chart(),
show_label=False,
visible=False
)
with gr.Group(visible=True, elem_classes="subtle-outline accuracy-hidden-panel") as hidden_group:
with gr.Column(elem_classes="flex items-center justify-center h-full"):
# example of how to add spacing:
# gr.HTML("<div style='margin-top: 2.9em;'></div>")
hidden_text0 = gr.Markdown("""
## True model performance is hidden
""",
elem_classes="text-center",)
gr.HTML("<div style='margin-top: 0.25em;'></div>")
hidden_text1 = gr.Markdown("""
In this problem setting the true model performance is assumed to be unknown (that is why we want to perform model selection!).
However, for this demo, we have computed the actual accuracy of each model in order to evaluate CODA's performance.
""",
elem_classes="text-center",
)
gr.HTML("<div style='margin-top: 0.25em;'></div>")
# with gr.Row():
# with gr.Column(scale=2):
# pass
# with gr.Column(scale=1, min_width=100):
# reveal_accuracy_button = gr.Button(
# "🔍 Reveal",
# variant="secondary",
# size="lg"
# )
# with gr.Column(scale=2):
# pass
with gr.Row():
reveal_accuracy_button = gr.Button(
"🔍 Reveal True Model Accuracies",
variant="secondary",
size="lg"
)
# example of how to add spacing:
# gr.HTML("<div style='margin-top: 2.9em;'></div>")
# Add debug delete button (only visible in debug mode)
if DEBUG_MODE:
delete_button = gr.Button("🗑️ Delete Current Image", variant="stop", size="lg")
# Set up button interactions
def start_demo():
global iteration_count, coda_selector, dataset, oracle, image_filenames
# Reset the demo state
iteration_count = 0
# Keep resampling until we get a subset where the initial best model (by CODA) is NOT the true best model
while True:
# Subsample dataset for this user
subsampled_indices = []
for class_idx in sorted(full_class_to_indices.keys()):
indices = full_class_to_indices[class_idx]
sampled = np.random.choice(indices, size=min_class_size, replace=False)
subsampled_indices.extend(sampled.tolist())
# Sort indices to maintain order
subsampled_indices.sort()
# Create subsampled dataset for this user
subsampled_preds = full_preds[:, subsampled_indices, :]
subsampled_labels = full_labels[subsampled_indices]
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
# Create Dataset object with subsampled data
dataset = Dataset.__new__(Dataset)
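# __new__ bypasses Dataset.__init__ so the fields can be populated directly from the
# pre-computed tensors (presumably avoiding whatever loading __init__ would otherwise do)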
dataset.preds = subsampled_preds
dataset.labels = subsampled_labels
dataset.device = device
# Create oracle and CODA selector for this user
oracle = Oracle(dataset, loss_fn=loss_fn)
coda_selector = CODA(dataset,
learning_rate=DEMO_LEARNING_RATE,
alpha=DEMO_ALPHA)
# Check which model is initially best according to CODA
probs_tensor = coda_selector.get_pbest()
probabilities = probs_tensor.detach().cpu().numpy().flatten()
coda_best_idx = np.argmax(probabilities)
# Get true best model according to oracle
true_losses = oracle.true_losses(dataset.preds)
true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
true_best_idx = np.argmax(true_accuracies)
# Accept this subset if CODA's initial best is NOT the true best
if coda_best_idx != true_best_idx:
break
# Otherwise, loop and resample
image, status, predictions = get_next_coda_image()
prob_plot = create_probability_chart()
acc_plot = create_accuracy_chart()
# Create HTML with inline help button
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)
def start_over():
global iteration_count, coda_selector, dataset, oracle, image_filenames
# Reset the demo state
iteration_count = 0
# Keep resampling until we get a subset where the initial best model (by CODA) is NOT the true best model
while True:
# Subsample dataset for this user (new random subsample)
subsampled_indices = []
for class_idx in sorted(full_class_to_indices.keys()):
indices = full_class_to_indices[class_idx]
sampled = np.random.choice(indices, size=min_class_size, replace=False)
subsampled_indices.extend(sampled.tolist())
# Sort indices to maintain order
subsampled_indices.sort()
# Create subsampled dataset for this user
subsampled_preds = full_preds[:, subsampled_indices, :]
subsampled_labels = full_labels[subsampled_indices]
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
# Create Dataset object with subsampled data
dataset = Dataset.__new__(Dataset)
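# As in start_demo(), __new__ bypasses Dataset.__init__ so the fields can be filled in directly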
dataset.preds = subsampled_preds
dataset.labels = subsampled_labels
dataset.device = device
# Create oracle and CODA selector for this user
oracle = Oracle(dataset, loss_fn=loss_fn)
coda_selector = CODA(dataset,
learning_rate=DEMO_LEARNING_RATE,
alpha=DEMO_ALPHA)
# Check which model is initially best according to CODA
probs_tensor = coda_selector.get_pbest()
probabilities = probs_tensor.detach().cpu().numpy().flatten()
coda_best_idx = np.argmax(probabilities)
# Get true best model according to oracle
true_losses = oracle.true_losses(dataset.preds)
true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
true_best_idx = np.argmax(true_accuracies)
# Accept this subset if CODA's initial best is NOT the true best
if coda_best_idx != true_best_idx:
break
# Otherwise, loop and resample
# Reset all displays
prob_plot = create_probability_chart()
acc_plot = create_accuracy_chart()
return None, "Demo reset. Click 'Start CODA Demo' to begin.", "### Model Predictions\n\n*Start the demo to see model votes!*", prob_plot, acc_plot, "", gr.update(visible=True), gr.update(visible=False)
def show_species_guide():
# Show species guide, hide intro content, show back button, hide guide button
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
def show_intro():
# Show intro content, hide species guide, hide back button, show guide button
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def show_prob_help():
return gr.update(visible=True)
def hide_prob_help():
return gr.update(visible=False)
def show_acc_help():
return gr.update(visible=True)
def hide_acc_help():
return gr.update(visible=False)
def show_selection_help():
return gr.update(visible=True)
def hide_selection_help():
return gr.update(visible=False)
def show_species_guide_popup():
return gr.update(visible=True)
def hide_species_guide_popup():
return gr.update(visible=False)
def reveal_accuracies():
"""Reveal accuracy plot and hide the hidden group"""
return gr.update(visible=True), gr.update(visible=False)
popup_start_button.click(
fn=start_demo,
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, popup_overlay, result_display, selection_help_button],
js="""
() => {
console.log('=== Panel Height Matching (Dynamic) ===');
function matchPanelHeights() {
const panels = document.querySelectorAll('.panel-container');
console.log('Found .panel-container elements:', panels.length);
const leftPanel = panels[0]; // prob_plot panel
const rightPanel = document.querySelector('.accuracy-hidden-panel'); // hidden_group panel
console.log('Left panel (prob):', leftPanel);
console.log('Right panel (hidden):', rightPanel);
if (leftPanel && rightPanel) {
const leftHeight = leftPanel.offsetHeight;
const rightHeight = rightPanel.offsetHeight;
const diff = leftHeight - rightHeight;
console.log('Left panel height:', leftHeight);
console.log('Right panel height:', rightHeight);
console.log('Height difference:', diff);
if (diff > 0) {
console.log('Setting right panel min-height to:', leftHeight + 'px');
rightPanel.style.minHeight = leftHeight + 'px';
rightPanel.style.display = 'flex';
rightPanel.style.flexDirection = 'column';
rightPanel.style.justifyContent = 'center';
console.log('Applied min-height and flex centering');
return true; // Success
} else {
console.log('No height adjustment needed (diff <= 0)');
return true; // Success
}
} else {
console.log('Panels not ready yet');
return false; // Not ready
}
}
// Check every 50ms for 3 seconds to catch multiple height changes
let attempts = 0;
const maxAttempts = 60; // 60 * 50ms = 3 seconds to catch both height changes
const checkInterval = setInterval(() => {
attempts++;
console.log('Attempt', attempts, 'to match heights');
matchPanelHeights(); // Always try, don't stop early
if (attempts >= maxAttempts) {
console.log('Finished checking after 3 seconds');
clearInterval(checkInterval);
}
}, 50); // Check every 50ms
}
"""
)
start_over_button.click(
fn=start_over,
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, result_display, popup_overlay, selection_help_button],
js="""
() => {
console.log('=== Panel Height Matching (Dynamic - Start Over) ===');
function matchPanelHeights() {
const panels = document.querySelectorAll('.panel-container');
console.log('Found .panel-container elements:', panels.length);
const leftPanel = panels[0]; // prob_plot panel
const rightPanel = document.querySelector('.accuracy-hidden-panel'); // hidden_group panel
console.log('Left panel (prob):', leftPanel);
console.log('Right panel (hidden):', rightPanel);
if (leftPanel && rightPanel) {
const leftHeight = leftPanel.offsetHeight;
const rightHeight = rightPanel.offsetHeight;
const diff = leftHeight - rightHeight;
console.log('Left panel height:', leftHeight);
console.log('Right panel height:', rightHeight);
console.log('Height difference:', diff);
if (diff > 0) {
console.log('Setting right panel min-height to:', leftHeight + 'px');
rightPanel.style.minHeight = leftHeight + 'px';
rightPanel.style.display = 'flex';
rightPanel.style.flexDirection = 'column';
rightPanel.style.justifyContent = 'center';
console.log('Applied min-height and flex centering');
return true; // Success
} else {
console.log('No height adjustment needed (diff <= 0)');
return true; // Success
}
} else {
console.log('Panels not ready yet');
return false; // Not ready
}
}
// Check every 50ms for 3 seconds to catch multiple height changes
let attempts = 0;
const maxAttempts = 60; // 60 * 50ms = 3 seconds to catch both height changes
const checkInterval = setInterval(() => {
attempts++;
console.log('Attempt', attempts, 'to match heights');
matchPanelHeights(); // Always try, don't stop early
if (attempts >= maxAttempts) {
console.log('Finished checking after 3 seconds');
clearInterval(checkInterval);
}
}, 50); // Check every 50ms
}
"""
)
guide_button.click(
fn=show_species_guide,
outputs=[intro_content, species_guide_content, back_button, guide_button]
)
back_button.click(
fn=show_intro,
outputs=[intro_content, species_guide_content, back_button, guide_button]
)
# Help popup handlers
prob_help_button.click(
fn=show_prob_help,
outputs=[prob_help_popup]
)
prob_help_close.click(
fn=hide_prob_help,
outputs=[prob_help_popup]
)
# acc_help_button.click(
# fn=show_acc_help,
# outputs=[acc_help_popup]
# )
acc_help_close.click(
fn=hide_acc_help,
outputs=[acc_help_popup]
)
selection_help_button.click(
fn=show_selection_help,
outputs=[selection_help_popup]
)
selection_help_close.click(
fn=hide_selection_help,
outputs=[selection_help_popup]
)
# Reveal accuracy button handler
reveal_accuracy_button.click(
fn=reveal_accuracies,
outputs=[accuracy_plot, hidden_group]
)
# Species guide popup handlers
view_guide_button.click(
fn=show_species_guide_popup,
outputs=[species_guide_popup]
)
species_guide_close.click(
fn=hide_species_guide_popup,
outputs=[species_guide_popup]
)
for btn in species_buttons:
btn.click(
fn=check_answer,
inputs=[gr.State(btn.value)],
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
)
idk_button.click(
fn=check_answer,
inputs=[gr.State("I don't know")],
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
)
# Wire up delete button in debug mode
if DEBUG_MODE:
delete_button.click(
fn=delete_current_image,
outputs=[result_display, image_display, model_predictions_display, prob_plot, accuracy_plot, status_with_help]
)
# Add JavaScript to handle inline help button clicks and dynamic image sizing
demo.load(
lambda: None,
outputs=[],
js="""
() => {
// Handle inline help button clicks
setTimeout(() => {
document.addEventListener('click', function(e) {
if (e.target && e.target.classList.contains('inline-help-btn')) {
e.preventDefault();
e.stopPropagation();
const hiddenBtn = document.getElementById('hidden-selection-help-btn');
if (hiddenBtn) {
hiddenBtn.click();
}
}
});
}, 100);
// Dynamic image sizing (NEW VERSION)
console.log('=== IMAGE SIZING V2 LOADED ===');
function adjustImageSize() {
const imageContainer = document.getElementById('main-image-display');
if (!imageContainer) {
console.log('[V2] Image container not found');
return false;
}
const viewportHeight = window.innerHeight;
const docHeight = document.documentElement.scrollHeight;
const currentImageHeight = imageContainer.offsetHeight;
// Calculate how much we're overflowing
const overflow = docHeight - viewportHeight;
// If we're not overflowing, increase image size
// If we are overflowing, decrease image size by the overflow amount
const adjustment = -overflow - 30; // Keep padding below bottom button
const targetHeight = currentImageHeight + adjustment;
console.log('[V2] viewport:', viewportHeight, 'docHeight:', docHeight, 'currentImg:', currentImageHeight, 'overflow:', overflow, 'target:', targetHeight);
// Only apply if reasonable
if (targetHeight > 300 && targetHeight < viewportHeight - 100) {
imageContainer.style.height = targetHeight + 'px';
imageContainer.style.maxHeight = targetHeight + 'px';
console.log('[V2] Set image height to:', targetHeight + 'px');
return true;
}
return false;
}
// Run after initial load
setTimeout(adjustImageSize, 500);
// Run periodically for first 5 seconds to catch layout changes
let attempts = 0;
const interval = setInterval(() => {
attempts++;
adjustImageSize();
if (attempts >= 50) { // 50 * 100ms = 5 seconds
clearInterval(interval);
}
}, 100);
// Re-adjust on window resize
window.addEventListener('resize', adjustImageSize);
}
""",
)
if __name__ == "__main__":
demo.launch(
# share=True,
# server_port=7861,
allowed_paths=["/"],
)