|
|
import os |
|
|
os.environ['GRADIO_TEMP_DIR'] = "tmp/" |
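# The env var above is set before `import gradio` so Gradio caches uploaded/served files under tmp/.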
|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
import random |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from collections import OrderedDict |
|
|
import numpy as np |
|
|
import torch |
|
|
import shutil |
|
|
import argparse |
|
|
|
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.image as mpimg |
|
|
from matplotlib.offsetbox import OffsetImage, AnnotationBbox |
|
|
|
|
|
from coda import CODA |
|
|
from coda.datasets import Dataset |
|
|
from coda.options import LOSS_FNS |
|
|
from coda.oracle import Oracle |
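# The demo drives CODA through three calls (see get_next_coda_image / check_answer below):
#   chosen_idx, p = coda_selector.get_next_item_to_label()  # CODA picks the next image to label
#   coda_selector.add_label(chosen_idx, class_idx, p)        # the user's label is fed back
#   coda_selector.get_pbest()                                # updated P(each model is best)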
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument('--debug', action='store_true', help='Enable debug mode with delete button') |
|
|
args_cli = parser.parse_args() |
|
|
DEBUG_MODE = args_cli.debug |
|
|
|
|
|
if DEBUG_MODE: |
|
|
print("Debug mode enabled - delete button will be available") |
|
|
|
|
|
os.makedirs('deleted_in_app', exist_ok=True) |
|
|
|
|
|
|
|
|
with open('iwildcam_demo_annotations.json', 'r') as f: |
|
|
data = json.load(f) |
|
|
|
|
|
SPECIES_MAP = OrderedDict([ |
|
|
(24, "Jaguar"), |
|
|
(10, "Ocelot"), |
|
|
(6, "Mountain Lion"), |
|
|
(101, "Common Eland"), |
|
|
(102, "Waterbuck"), |
|
|
]) |
|
|
NAME_TO_ID = {name: id for id, name in SPECIES_MAP.items()} |
|
|
|
|
|
|
|
|
CLASS_NAMES = ["Jaguar", "Ocelot", "Mountain Lion", "Common Eland", "Waterbuck"] |
|
|
NAME_TO_CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)} |
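# Two id schemes are in play: SPECIES_MAP keys are the dataset's category ids (from the
# annotation JSON), while the order of CLASS_NAMES defines the class indices used by the
# prediction tensors. NAME_TO_ID and NAME_TO_CLASS_IDX map display names into each scheme.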
|
|
|
|
|
|
|
|
MODEL_INFO = [ |
|
|
{"org": "Facebook", "name": "PE-Core", "logo": "logos/meta.png"}, |
|
|
{"org": "Google", "name": "SigLIP2", "logo": "logos/google.png"}, |
|
|
{"org": "OpenAI", "name": "CLIPViT-L", "logo": "logos/openai.png"}, |
|
|
{"org": "Imageomics", "name": "BioCLIP2", "logo": "logos/imageomics.png"}, |
|
|
{"org": "LAION", "name": "LAION CLIP", "logo": "logos/laion.png"} |
|
|
] |
|
|
|
|
|
DEMO_LEARNING_RATE = 0.05 |
|
|
DEMO_ALPHA = 0.9 |
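# CODA hyperparameters used by this demo; passed directly to the CODA(...) constructor
# in start_demo()/start_over().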
|
|
|
|
|
|
|
|
USE_CONFUSION_MATRIX = False |
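# When True, the right-hand panel shows CODA's per-model confusion-matrix estimates
# instead of the true-accuracy bar chart (see create_accuracy_chart).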
|
|
|
|
|
def create_species_guide_content(): |
|
|
"""Create the species identification guide content""" |
|
|
with gr.Column(): |
|
|
gr.Markdown(""" |
|
|
# Species Classification Guide |
|
|
|
|
|
### Learn to identify the five wildlife species in this demo. |
|
|
|
|
|
## Jaguar |
|
|
""") |
|
|
|
|
|
gr.Image("species_id/jaguar.jpg", label="Jaguar example image", show_label=False) |
|
|
|
|
|
gr.Markdown(""" |
|
|
#### The largest cat in the Americas, with a stocky, muscular build and a broad head. Coat is patterned with rosettes that often have central spots inside. |
|
|
|
|
|
---- |
|
|
|
|
|
## Ocelot |
|
|
|
|
|
""") |
|
|
|
|
|
gr.Image("species_id/ocelot.jpg", label="Ocelot example image", show_label=False) |
|
|
|
|
|
gr.Markdown(""" |
|
|
#### Smaller and leaner than a jaguar, with more elongated markings and rounder ears. |
|
|
|
|
|
---- |
|
|
|
|
|
## Mountain Lion |
|
|
""") |
|
|
|
|
|
gr.Image("species_id/mountainlion.jpg", label="Mountain lion example image", show_label=False) |
|
|
|
|
|
gr.Markdown(""" |
|
|
#### Also called cougar or puma, this cat has a plain tawny or grayish coat without spots. Its long tail and uniformly colored fur distinguish it from jaguars and ocelots. |
|
|
|
|
|
---- |
|
|
|
|
|
## Common Eland |
|
|
|
|
|
""") |
|
|
|
|
|
gr.Image("species_id/commoneland.jpg", label="Eland example image", show_label=False) |
|
|
|
|
|
gr.Markdown(""" |
|
|
### The largest antelope species. Identifiable by its spiraled horns on both sexes. Lighter tan coat than a waterbuck. |
|
|
|
|
|
---- |
|
|
|
|
|
## Waterbuck |
|
|
""") |
|
|
|
|
|
gr.Image("species_id/waterbuck.jpg", label="Waterbuck example image", show_label=False) |
|
|
|
|
|
gr.Markdown(""" |
|
|
#### A shaggy, dark brown antelope. Identifiable by backward-curving horns in males, no horns on females. Larger, rounder ears and darker coat than the common eland. |
|
|
|
|
|
---- |
|
|
|
|
|
""") |
|
|
|
|
|
|
|
|
images_data = [] |
|
|
for annotation in tqdm(data['annotations'], desc='Loading annotations'): |
|
|
image_id = annotation['image_id'] |
|
|
category_id = annotation['category_id'] |
|
|
image_info = next((img for img in data['images'] if img['id'] == image_id), None) |
|
|
if image_info: |
|
|
images_data.append({ |
|
|
'filename': image_info['file_name'], |
|
|
'species_id': category_id, |
|
|
'species_name': SPECIES_MAP[category_id] |
|
|
}) |
|
|
print(f"Loaded {len(images_data)} images for the quiz") |
|
|
|
|
|
|
|
|
with open('images.txt', 'r') as f: |
|
|
full_image_filenames = [line.strip() for line in f.readlines() if line.strip()] |
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
# map_location keeps these loads working on CPU-only machines.
full_preds = torch.load("iwildcam_demo.pt", map_location=device)
full_labels = torch.load("iwildcam_demo_labels.pt", map_location=device)
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
|
full_class_to_indices = defaultdict(list) |
|
|
for idx, label in enumerate(full_labels): |
|
|
class_idx = label.item() |
|
|
full_class_to_indices[class_idx].append(idx) |
|
|
|
|
|
|
|
|
min_class_size = min(len(indices) for indices in full_class_to_indices.values()) |
|
|
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)") |
|
|
|
|
|
|
|
|
loss_fn = LOSS_FNS['acc'] |
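# 'acc' selects the accuracy-based loss, so (1 - loss) corresponds to model accuracy
# (see create_accuracy_bar_chart).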
|
|
|
|
|
|
|
|
current_image_info = None |
|
|
coda_selector = None |
|
|
oracle = None |
|
|
dataset = None |
|
|
image_filenames = None |
|
|
iteration_count = 0 |
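# The globals above hold per-session demo state. They are (re)initialized by
# start_demo()/start_over() and updated as the user labels images.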
|
|
|
|
|
def get_model_predictions(chosen_idx): |
|
|
"""Get model predictions and scores for a specific image""" |
|
|
global dataset |
|
|
|
|
|
if dataset is None or chosen_idx >= dataset.preds.shape[1]: |
|
|
return "No predictions available" |
|
|
|
|
|
|
|
|
image_preds = dataset.preds[:, chosen_idx, :].detach().cpu().numpy() |
|
|
|
|
|
predictions_list = [] |
|
|
|
|
|
for model_idx in range(image_preds.shape[0]): |
|
|
model_scores = image_preds[model_idx] |
|
|
predicted_class_idx = model_scores.argmax() |
|
|
predicted_class_name = CLASS_NAMES[predicted_class_idx] |
|
|
confidence = model_scores[predicted_class_idx] |
|
|
|
|
|
model_info = MODEL_INFO[model_idx] |
|
|
predictions_list.append(f"**{model_info['name']}:** {predicted_class_name} *({confidence:.3f})*") |
|
|
|
|
|
predictions_text = "### Model Predictions\n\n" + " | ".join(predictions_list) |
|
|
|
|
|
return predictions_text |
|
|
|
|
|
def add_logo_to_x_axis(ax, x_pos, logo_path, model_name, height_px=35): |
|
|
"""Add a logo image to x-axis next to model name""" |
|
|
try: |
|
|
img = mpimg.imread(logo_path) |
|
|
|
|
|
|
|
|
        # Scale the logo so its shorter side renders at roughly height_px pixels
        # (OffsetImage zoom is expressed in points, hence the dpi-to-72pt conversion).
        zoom = height_px / min(img.shape[0], img.shape[1]) / ax.figure.dpi * 72
|
|
imagebox = OffsetImage(img, zoom=zoom) |
|
|
|
|
|
|
|
|
logo_offset = -0.28 |
|
|
y_offset = -0.08 |
|
|
ab = AnnotationBbox(imagebox, (x_pos + logo_offset, y_offset), |
|
|
xycoords=('data', 'axes fraction'), frameon=False) |
|
|
ax.add_artist(ab) |
|
|
except Exception as e: |
|
|
print(f"Could not load logo {logo_path}: {e}") |
|
|
|
|
|
def get_next_coda_image(): |
|
|
"""Get the next image that CODA wants labeled""" |
|
|
global current_image_info, coda_selector, iteration_count |
|
|
|
|
|
|
|
|
chosen_idx, selection_prob = coda_selector.get_next_item_to_label() |
|
|
print("CODA chosen_idx, selection prob:", chosen_idx, selection_prob) |
|
|
|
|
|
|
|
|
if chosen_idx < len(image_filenames): |
|
|
filename = image_filenames[chosen_idx] |
|
|
image_path = os.path.join('iwildcam_demo_images', filename) |
|
|
print("Next image is", filename) |
|
|
|
|
|
|
|
|
current_image_info = None |
|
|
for annotation in data['annotations']: |
|
|
image_id = annotation['image_id'] |
|
|
image_info = next((img for img in data['images'] if img['id'] == image_id), None) |
|
|
if image_info and image_info['file_name'] == filename: |
|
|
current_image_info = { |
|
|
'filename': filename, |
|
|
'species_id': annotation['category_id'], |
|
|
'species_name': SPECIES_MAP[annotation['category_id']], |
|
|
'chosen_idx': chosen_idx, |
|
|
'selection_prob': selection_prob |
|
|
} |
|
|
break |
|
|
|
|
|
try: |
|
|
image = Image.open(image_path) |
|
|
predictions = get_model_predictions(chosen_idx) |
|
|
return image, f"Iteration {iteration_count}: CODA selected this image for labeling", predictions |
|
|
except Exception as e: |
|
|
print(f"Error loading image {image_path}: {e}") |
|
|
return None, f"Error loading image: {e}", "No predictions available" |
|
|
else: |
|
|
return None, "Image index out of range", "No predictions available" |
|
|
|
|
|
def delete_current_image(): |
|
|
"""Delete the current image by moving it to deleted_in_app directory""" |
|
|
global current_image_info, coda_selector |
|
|
|
|
|
if current_image_info is None: |
|
|
return "No image to delete!", None, "No predictions", None, None, "" |
|
|
|
|
|
filename = current_image_info['filename'] |
|
|
chosen_idx = current_image_info['chosen_idx'] |
|
|
source_path = os.path.join('iwildcam_demo_images', filename) |
|
|
dest_path = os.path.join('deleted_in_app', filename) |
|
|
|
|
|
try: |
|
|
shutil.move(source_path, dest_path) |
|
|
result = f"✓ Moved {filename} to deleted_in_app/" |
|
|
print(f"Deleted image: {filename}") |
|
|
|
|
|
|
|
|
if chosen_idx in coda_selector.unlabeled_idxs: |
|
|
coda_selector.unlabeled_idxs.remove(chosen_idx) |
|
|
except Exception as e: |
|
|
result = f"Error deleting image: {e}" |
|
|
print(f"Error deleting {filename}: {e}") |
|
|
|
|
|
|
|
|
next_image, status, predictions = get_next_coda_image() |
|
|
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>' |
|
|
|
|
|
|
|
|
prob_plot = create_probability_chart() |
|
|
accuracy_plot = create_accuracy_chart() |
|
|
|
|
|
return result, next_image, predictions, prob_plot, accuracy_plot, status_html |
|
|
|
|
|
def check_answer(user_choice): |
|
|
"""Process user's label and update CODA""" |
|
|
global current_image_info, coda_selector, iteration_count |
|
|
|
|
|
if current_image_info is None: |
|
|
return "Please load an image first!", "", None, "No predictions", None, None |
|
|
|
|
|
correct_species = current_image_info['species_name'] |
|
|
chosen_idx = current_image_info['chosen_idx'] |
|
|
selection_prob = current_image_info['selection_prob'] |
|
|
|
|
|
|
|
|
if user_choice == "I don't know": |
|
|
|
|
|
coda_selector.unlabeled_idxs.remove(chosen_idx) |
|
|
result = f"The last image was skipped and will not be used for model selection. The correct species was {correct_species}. " |
|
|
else: |
|
|
user_class_idx = NAME_TO_CLASS_IDX.get(user_choice, NAME_TO_CLASS_IDX[correct_species]) |
|
|
if user_choice == correct_species: |
|
|
result = f"🎉 Your last classification was correct! It was indeed a {correct_species}." |
|
|
else: |
|
|
result = f"❌ Your last classification was incorrect. It was a {correct_species}, not a {user_choice}. This may mislead the model selection process!" |
|
|
|
|
|
|
|
|
coda_selector.add_label(chosen_idx, user_class_idx, selection_prob) |
|
|
|
|
|
iteration_count += 1 |
|
|
|
|
|
|
|
|
prob_plot = create_probability_chart() |
|
|
accuracy_plot = create_accuracy_chart() |
|
|
|
|
|
|
|
|
next_image, status, predictions = get_next_coda_image() |
|
|
|
|
|
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>' |
|
|
return result, status_html, next_image, predictions, prob_plot, accuracy_plot |
|
|
|
|
|
def create_probability_chart(): |
|
|
"""Create a bar chart showing probability each model is best""" |
|
|
global coda_selector |
|
|
|
|
|
if coda_selector is None: |
|
|
|
|
|
model_labels = [info['name'] for info in MODEL_INFO] |
|
|
probabilities = np.ones(len(MODEL_INFO)) / len(MODEL_INFO) |
|
|
else: |
|
|
probs_tensor = coda_selector.get_pbest() |
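        # get_pbest() returns CODA's current estimate of the probability that each
        # candidate model is the best one.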
|
|
probabilities = probs_tensor.detach().cpu().numpy().flatten() |
|
|
model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(probabilities)]] |
|
|
|
|
|
|
|
|
best_idx = np.argmax(probabilities) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150) |
|
|
|
|
|
|
|
|
colors = ['orange' if i == best_idx else 'steelblue' for i in range(len(model_labels))] |
|
|
bars = ax.bar(range(len(model_labels)), probabilities, color=colors, alpha=0.7) |
|
|
|
|
|
|
|
|
ax.text(best_idx, probabilities[best_idx] + 0.0025, 'Current best guess', |
|
|
ha='center', va='bottom', fontsize=12, fontweight='bold') |
|
|
|
|
|
ax.set_ylabel('Probability model is best', fontsize=12) |
|
|
ax.set_title(f'CODA Model Selection Probabilities (Iteration {iteration_count})', fontsize=12) |
|
|
ax.set_ylim(np.min(probabilities) - 0.01, np.max(probabilities) + 0.02) |
|
|
|
|
|
|
|
|
ax.set_xticks(range(len(model_labels))) |
|
|
ax.set_xticklabels(model_labels, fontsize=12, ha='center') |
|
|
|
|
|
|
|
|
for i, model_info in enumerate(MODEL_INFO[:len(probabilities)]): |
|
|
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name']) |
|
|
plt.yticks(fontsize=12) |
|
|
plt.tight_layout() |
|
|
|
|
|
|
|
|
temp_fig = fig |
|
|
plt.close(fig) |
|
|
return temp_fig |
|
|
|
|
|
def create_accuracy_chart(): |
|
|
"""Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle""" |
|
|
global coda_selector, oracle, dataset, iteration_count |
|
|
|
|
|
if USE_CONFUSION_MATRIX: |
|
|
return create_confusion_matrix_chart() |
|
|
else: |
|
|
return create_accuracy_bar_chart() |
|
|
|
|
|
def create_confusion_matrix_chart(): |
|
|
"""Create confusion matrix estimates for each model side by side""" |
|
|
global coda_selector, iteration_count |
|
|
|
|
|
if coda_selector is None: |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150) |
|
|
ax.text(0.5, 0.5, 'Start demo to see confusion matrices', |
|
|
ha='center', va='center', fontsize=12) |
|
|
ax.axis('off') |
|
|
plt.tight_layout() |
|
|
temp_fig = fig |
|
|
plt.close(fig) |
|
|
return temp_fig |
|
|
|
|
|
|
|
|
dirichlets = coda_selector.dirichlets |
|
|
num_models = dirichlets.shape[0] |
|
|
num_classes = dirichlets.shape[1] |
|
|
|
|
|
|
|
|
|
|
|
confusion_matrices = [] |
|
|
for model_idx in range(num_models): |
|
|
alpha = dirichlets[model_idx].detach().cpu().numpy() |
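        # Row-normalizing the Dirichlet parameters gives each row's Dirichlet mean, i.e.
        # CODA's expected confusion matrix (rows: true class, columns: predicted class).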
|
|
|
|
|
conf_matrix = alpha / alpha.sum(axis=1, keepdims=True) |
|
|
confusion_matrices.append(conf_matrix) |
|
|
|
|
|
|
|
|
|
|
|
fig_width = num_models * 2.4 |
|
|
fig, axes = plt.subplots(1, num_models, figsize=(fig_width, 2.8), dpi=150) |
|
|
if num_models == 1: |
|
|
axes = [axes] |
|
|
|
|
|
|
|
|
species_labels = ['Jag', 'Oce', 'M.L.', 'C.E.', 'Wat'] |
|
|
|
|
|
for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)): |
|
|
|
|
|
|
|
|
        # Take the 16th root (sqrt applied four times) purely for display contrast,
        # so small off-diagonal probabilities remain visible in the heatmap.
        sqrt_conf_matrix = np.sqrt(np.sqrt(np.sqrt(np.sqrt(conf_matrix))))
|
|
|
|
|
|
|
|
im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto') |
|
|
|
|
|
|
|
|
model_info = MODEL_INFO[model_idx] |
|
|
ax.set_title(f"{model_info['name']}", fontsize=10, pad=5) |
|
|
|
|
|
|
|
|
if model_idx == 0: |
|
|
ax.set_ylabel('True class', fontsize=9) |
|
|
ax.set_xlabel('Predicted', fontsize=9) |
|
|
|
|
|
|
|
|
ax.set_xticks(range(num_classes)) |
|
|
ax.set_yticks(range(num_classes)) |
|
|
ax.set_xticklabels(species_labels[:num_classes], fontsize=8) |
|
|
ax.set_yticklabels(species_labels[:num_classes], fontsize=8) |
|
|
|
|
|
plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98) |
|
|
plt.tight_layout() |
|
|
|
|
|
temp_fig = fig |
|
|
plt.close(fig) |
|
|
return temp_fig |
|
|
|
|
|
def create_accuracy_bar_chart(): |
|
|
"""Create a bar chart showing true accuracy of each model (with muted colors)""" |
|
|
global oracle, dataset |
|
|
|
|
|
if oracle is None or dataset is None: |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150) |
|
|
ax.text(0.5, 0.5, 'Start demo to see model accuracies', |
|
|
ha='center', va='center', fontsize=12) |
|
|
ax.axis('off') |
|
|
plt.tight_layout() |
|
|
temp_fig = fig |
|
|
plt.close(fig) |
|
|
return temp_fig |
|
|
|
|
|
true_losses = oracle.true_losses(dataset.preds) |
|
|
|
|
|
accuracies = (1 - true_losses).detach().cpu().numpy().flatten() |
|
|
model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(accuracies)]] |
|
|
|
|
|
|
|
|
best_idx = np.argmax(accuracies) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150) |
|
|
|
|
|
|
|
|
colors = ['#F8481C' if i == best_idx else '#F8BBD0' for i in range(len(model_labels))] |
|
|
bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.85) |
|
|
|
|
|
|
|
|
ax.text(best_idx, accuracies[best_idx] + 0.0025, 'True best model', |
|
|
ha='center', va='bottom', fontsize=12, fontweight='bold') |
|
|
|
|
|
ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12) |
|
|
ax.set_title('True Model Accuracies', fontsize=12) |
|
|
y_min = np.min(accuracies) - 0.025 |
|
|
y_max = np.max(accuracies) + 0.05 |
|
|
ax.set_ylim(y_min, y_max) |
|
|
|
|
|
|
|
|
for i, (bar, acc) in enumerate(zip(bars, accuracies)): |
|
|
|
|
|
text_y = (y_min + acc) / 2 |
|
|
|
|
|
text_color = '#000000' |
|
|
ax.text(i, text_y, f'{acc:.3f}', |
|
|
ha='center', va='center', fontsize=10, fontweight='bold', color=text_color) |
|
|
|
|
|
|
|
|
ax.set_xticks(range(len(model_labels))) |
|
|
ax.set_xticklabels(model_labels, fontsize=12, ha='center') |
|
|
|
|
|
|
|
|
for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]): |
|
|
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name']) |
|
|
plt.yticks(fontsize=12) |
|
|
plt.tight_layout() |
|
|
|
|
|
|
|
|
temp_fig = fig |
|
|
plt.close(fig) |
|
|
return temp_fig |
|
|
|
|
|
|
|
|
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge", |
|
|
theme=gr.themes.Base(), |
|
|
css=""" |
|
|
.subtle-outline { |
|
|
border: 1px solid var(--border-color-primary) !important; |
|
|
background: var(--background-fill-secondary) !important; |
|
|
border-radius: var(--radius-lg); |
|
|
padding: 1rem; |
|
|
} |
|
|
.subtle-outline .flex { |
|
|
background-color: var(--background-fill-secondary) !important; |
|
|
} |
|
|
|
|
|
/* Light blue background for model predictions panel */ |
|
|
.model-predictions-panel { |
|
|
border: 1px solid #6B8CBF !important; |
|
|
background: #D6E4F5 !important; |
|
|
border-radius: var(--radius-lg); |
|
|
padding: 0.3rem !important; |
|
|
margin: 0.2rem 0 !important; |
|
|
} |
|
|
.model-predictions-panel .flex { |
|
|
background-color: #D6E4F5 !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
} |
|
|
.model-predictions-panel * { |
|
|
color: #1a1a1a !important; |
|
|
} |
|
|
|
|
|
/* Popup overlay styles */ |
|
|
.popup-overlay { |
|
|
position: fixed; |
|
|
top: 0; |
|
|
left: 0; |
|
|
width: 100%; |
|
|
height: 100%; |
|
|
background-color: rgba(0, 0, 0, 0.5); |
|
|
z-index: 1000; |
|
|
display: flex; |
|
|
justify-content: center; |
|
|
align-items: center; |
|
|
} |
|
|
|
|
|
.popup-overlay > div { |
|
|
background: transparent !important; |
|
|
border: none !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
} |
|
|
|
|
|
.popup-content { |
|
|
background: var(--background-fill-primary) !important; |
|
|
padding: 2rem !important; |
|
|
border-radius: 1rem !important; |
|
|
max-width: 850px; |
|
|
width: 90%; |
|
|
max-height: 80vh; |
|
|
overflow-y: auto; |
|
|
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3); |
|
|
border: none !important; |
|
|
margin: 0 !important; |
|
|
color: var(--body-text-color) !important; |
|
|
} |
|
|
|
|
|
.popup-content > div { |
|
|
background: var(--background-fill-primary) !important; |
|
|
border: none !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
overflow-y: visible !important; |
|
|
max-height: none !important; |
|
|
} |
|
|
|
|
|
.popup-content h1, |
|
|
.popup-content h2, |
|
|
.popup-content h3, |
|
|
.popup-content p, |
|
|
.popup-content li { |
|
|
color: var(--body-text-color) !important; |
|
|
} |
|
|
|
|
|
/* Ensure gradio column components don't interfere with scrolling */ |
|
|
.popup-content .gradio-column { |
|
|
overflow-y: visible !important; |
|
|
max-height: none !important; |
|
|
} |
|
|
|
|
|
/* Ensure images in popup are responsive */ |
|
|
.popup-content img { |
|
|
max-width: 100% !important; |
|
|
height: auto !important; |
|
|
} |
|
|
|
|
|
/* Center title */ |
|
|
.text-center { |
|
|
text-align: center !important; |
|
|
} |
|
|
|
|
|
/* Right align text */ |
|
|
.text-right { |
|
|
text-align: right !important; |
|
|
} |
|
|
|
|
|
/* Subtitle styling */ |
|
|
.subtitle { |
|
|
text-align: center !important; |
|
|
font-weight: 300 !important; |
|
|
color: #666 !important; |
|
|
margin-top: -0.5rem !important; |
|
|
} |
|
|
|
|
|
/* Question mark icon styling */ |
|
|
.panel-container { |
|
|
position: relative; |
|
|
} |
|
|
|
|
|
.help-icon { |
|
|
position: absolute; |
|
|
top: 5px; |
|
|
right: 5px; |
|
|
width: 25px; |
|
|
height: 25px; |
|
|
background-color: #f8f9fa; |
|
|
color: #6c757d; |
|
|
border: 1px solid #dee2e6; |
|
|
border-radius: 50%; |
|
|
display: flex; |
|
|
align-items: center; |
|
|
justify-content: center; |
|
|
cursor: pointer; |
|
|
font-size: 13px; |
|
|
font-weight: 600; |
|
|
z-index: 10; |
|
|
transition: all 0.2s ease; |
|
|
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); |
|
|
} |
|
|
|
|
|
.help-icon:hover { |
|
|
background-color: #e9ecef; |
|
|
color: #495057; |
|
|
border-color: #adb5bd; |
|
|
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15); |
|
|
} |
|
|
|
|
|
/* Help popup styles */ |
|
|
.help-popup-overlay { |
|
|
position: fixed; |
|
|
top: 0; |
|
|
left: 0; |
|
|
width: 100%; |
|
|
height: 100%; |
|
|
background-color: rgba(0, 0, 0, 0.5); |
|
|
z-index: 1001; |
|
|
display: flex; |
|
|
justify-content: center; |
|
|
align-items: center; |
|
|
} |
|
|
|
|
|
.help-popup-overlay > div { |
|
|
background: transparent !important; |
|
|
border: none !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
} |
|
|
|
|
|
.help-popup-content { |
|
|
background: var(--background-fill-primary) !important; |
|
|
padding: 1.5rem !important; |
|
|
border-radius: 0.5rem !important; |
|
|
max-width: 600px; |
|
|
width: 90%; |
|
|
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3); |
|
|
border: none !important; |
|
|
margin: 0 !important; |
|
|
color: var(--body-text-color) !important; |
|
|
} |
|
|
|
|
|
.help-popup-content > div { |
|
|
background: var(--background-fill-primary) !important; |
|
|
border: none !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
} |
|
|
|
|
|
.help-popup-content h1, |
|
|
.help-popup-content h2, |
|
|
.help-popup-content h3, |
|
|
.help-popup-content p, |
|
|
.help-popup-content li { |
|
|
color: var(--body-text-color) !important; |
|
|
} |
|
|
|
|
|
/* Inline help button */ |
|
|
.inline-help-btn { |
|
|
display: inline-block; |
|
|
width: 20px; |
|
|
height: 20px; |
|
|
background-color: #f8f9fa; |
|
|
color: #6c757d; |
|
|
border: 1px solid #dee2e6; |
|
|
border-radius: 50%; |
|
|
text-align: center; |
|
|
line-height: 18px; |
|
|
cursor: pointer; |
|
|
font-size: 11px; |
|
|
font-weight: 600; |
|
|
margin-left: 8px; |
|
|
vertical-align: middle; |
|
|
transition: all 0.2s ease; |
|
|
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); |
|
|
} |
|
|
|
|
|
.inline-help-btn:hover { |
|
|
background-color: #e9ecef; |
|
|
color: #495057; |
|
|
border-color: #adb5bd; |
|
|
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15); |
|
|
} |
|
|
|
|
|
#hidden-selection-help-btn { |
|
|
display: none; |
|
|
} |
|
|
|
|
|
/* Reduce spacing around status text */ |
|
|
.status-text { |
|
|
margin: 0 !important; |
|
|
padding: 0 !important; |
|
|
} |
|
|
|
|
|
.status-text > div { |
|
|
margin: 0 !important; |
|
|
padding: 0 !important; |
|
|
} |
|
|
|
|
|
/* Compact model predictions panel */ |
|
|
.compact-predictions { |
|
|
line-height: 1.1 !important; |
|
|
margin: 0 !important; |
|
|
padding: 0.1rem !important; |
|
|
} |
|
|
|
|
|
.compact-predictions p { |
|
|
margin: 0.05rem 0 !important; |
|
|
} |
|
|
|
|
|
.compact-predictions h3 { |
|
|
margin: 0 0 0.1rem 0 !important; |
|
|
} |
|
|
|
|
|
/* Target the subtle-outline group that contains predictions */ |
|
|
.subtle-outline { |
|
|
padding: 0.3rem !important; |
|
|
margin: 0.2rem 0 !important; |
|
|
} |
|
|
|
|
|
/* Target the column inside the outline */ |
|
|
.subtle-outline .flex { |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
} |
|
|
|
|
|
/* Ensure text in predictions panel is visible in dark mode */ |
|
|
.subtle-outline * { |
|
|
color: var(--body-text-color) !important; |
|
|
} |
|
|
""", |
|
|
delete_cache=(3600,3600) |
|
|
) as demo: |
|
|
|
|
|
gr.Markdown("# CODA: Consensus-Driven Active Model Selection", elem_classes="text-center") |
|
|
gr.Markdown("*Figure out which model is best by actively annotating data. See <a href='https://www.arxiv.org/abs/2507.23771'>the paper</a> for more details!*", elem_classes="text-center") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
view_guide_button = gr.Button("📖 View Species Guide", variant="secondary", size="lg") |
|
|
start_over_button = gr.Button("Start Over", variant="secondary", size="lg") |
|
|
|
|
|
|
|
|
with gr.Group(visible=True, elem_classes="popup-overlay") as popup_overlay: |
|
|
with gr.Group(elem_classes="popup-content"): |
|
|
|
|
|
intro_content = gr.Markdown(""" |
|
|
# CODA: Consensus-Driven Active Model Selection |
|
|
|
|
|
## Wildlife Photo Classification Challenge |
|
|
|
|
|
You are a wildlife ecologist who has just collected a season's worth of imagery from cameras |
|
|
deployed in Africa and Central and South America. You want to know what species occur in this imagery, |
|
|
and you hope to use a pre-trained classifier to give you answers quickly. |
|
|
But which one should you use? |
|
|
|
|
|
Instead of labeling a large validation set, our new method, **CODA**, enables you to perform **active model selection**. |
|
|
That is, CODA uses predictions from candidate models to guide the labeling process, querying you (a species identification expert) |
|
|
for labels on a select few images that will most efficiently differentiate between your candidate machine learning models. |
|
|
|
|
|
This demo lets you try CODA yourself! First, **become a species identification expert by reading our classification guide** |
|
|
so that you will be equipped to provide ground truth labels. Then, watch as CODA narrows down the best model over time |
|
|
as you provide labels for the query images. You will see that, with your input, CODA can identify the best
candidate model with as few as ten (correctly) labeled images.
|
|
|
|
|
""") |
|
|
|
|
|
|
|
|
with gr.Column(visible=False) as species_guide_content: |
|
|
create_species_guide_content() |
|
|
|
|
|
|
|
|
gr.HTML("<div style='margin-top: 0.1em;'></div>") |
|
|
|
|
|
with gr.Row(): |
|
|
back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False) |
|
|
guide_button = gr.Button("View Species Classification Guide", variant="primary", size="lg") |
|
|
popup_start_button = gr.Button("Start Demo", variant="secondary", size="lg") |
|
|
|
|
|
|
|
|
with gr.Group(visible=False, elem_classes="help-popup-overlay") as prob_help_popup: |
|
|
with gr.Group(elem_classes="help-popup-content"): |
|
|
gr.Markdown(""" |
|
|
## CODA Model Selection Probabilities |
|
|
|
|
|
This chart shows CODA's current confidence in each candidate classifier being the best performer. |
|
|
|
|
|
**How to read this chart:** |
|
|
- Each bar represents one of the candidate machine learning classifiers |
|
|
- The height of each bar shows the probability (0-100%) that this model is the best, according to CODA |
|
|
- The orange bar indicates CODA's current best guess |
|
|
- As you provide more labels, CODA updates these probabilities |
|
|
|
|
|
**What you'll see:** |
|
|
- CODA initializes these probabilities based on each classifier's agreement with the consensus votes of *all* classifiers, |
|
|
providing informative priors |
|
|
- As you label images, some models will gain confidence while others lose it |
|
|
- The goal is for one model to clearly emerge as the winner |
|
|
|
|
|
More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)! |
|
|
|
|
|
**What models are these?** |
|
|
|
|
|
For this demo, we selected 5 zero-shot classifiers that would be reasonable choices for someone who |
|
|
wanted to classify wildlife imagery. The models are: facebook/PE-Core-L14-336, |
|
|
google/siglip2-so400m-patch16-naflex, openai/clip-vit-large-patch14, imageomics/bioclip-2, and |
|
|
laion/CLIP-ViT-L-14-laion2B-s32B-b82K. Our goal is not to make any general claims about the performance |
|
|
of these models but rather to provide a realistic set of candidates for demonstrating CODA. |
|
|
|
|
|
""") |
|
|
gr.HTML("<div style='margin-top: 0.1em;'></div>") |
|
|
prob_help_close = gr.Button("Close", variant="secondary") |
|
|
|
|
|
with gr.Group(visible=False, elem_classes="help-popup-overlay") as acc_help_popup: |
|
|
with gr.Group(elem_classes="help-popup-content"): |
|
|
gr.Markdown(""" |
|
|
## True Model Accuracies |
|
|
|
|
|
This chart shows the actual performance of each model on the complete dataset (only possible with oracle knowledge). |
|
|
|
|
|
**How to read this chart:** |
|
|
- Each bar represents the true accuracy of one model |
|
|
- The red bar shows the actual best-performing model |
|
|
- This information is hidden from CODA during the selection process |
|
|
- You can compare this with CODA's estimates to see how well it's doing |
|
|
|
|
|
**Why this matters:** |
|
|
- This represents the "ground truth" that CODA is trying to discover |
|
|
- In real scenarios, you wouldn't know these true accuracies beforehand |
|
|
- The demo shows these to illustrate how CODA's estimates align with reality |
|
|
|
|
|
""") |
|
|
acc_help_close = gr.Button("Close", variant="secondary") |
|
|
|
|
|
with gr.Group(visible=False, elem_classes="help-popup-overlay") as selection_help_popup: |
|
|
with gr.Group(elem_classes="help-popup-content"): |
|
|
gr.Markdown(""" |
|
|
## How CODA selects images for labeling |
|
|
|
|
|
CODA selects images that best differentiate top-performing classifiers from each other. It |
|
|
does this by constructing a probabilistic model of which classifier is best (see |
|
|
the plot at the bottom-left). Each iteration, CODA selects an image to be labeled |
|
|
based on how much a label for that image is expected to affect the probabilistic model. |
|
|
|
|
|
Intuitively, CODA will select images where the top classifiers disagree, since knowing the ground truth for these images will provide |
|
|
the most information about which classifier is best overall. |
|
|
|
|
|
More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)! |
|
|
|
|
|
**What data is this?** |
|
|
|
|
|
We selected a subset of 5 species from the iWildCam dataset and subsampled ~500 images for this demo.
|
|
Each refresh will generate a slightly different subset, leading to slightly different model selection |
|
|
performance. |
|
|
|
|
|
""") |
|
|
gr.HTML("<div style='margin-top: 0.1em;'></div>") |
|
|
|
|
|
selection_help_close = gr.Button("Close", variant="secondary") |
|
|
|
|
|
|
|
|
with gr.Group(visible=False, elem_classes="popup-overlay") as species_guide_popup: |
|
|
with gr.Group(elem_classes="popup-content"): |
|
|
create_species_guide_content() |
|
|
|
|
|
|
|
|
gr.HTML("<div style='margin-top: 0.1em;'></div>") |
|
|
|
|
|
species_guide_close = gr.Button("Go back to demo", variant="primary", size="lg") |
|
|
|
|
|
|
|
|
selection_help_button = gr.Button("", visible=False, elem_id="hidden-selection-help-btn") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
status_with_help = gr.HTML("", visible=True, elem_classes="status-text") |
|
|
with gr.Column(scale=2): |
|
|
result_display = gr.Markdown("", visible=True, elem_classes="text-right") |
|
|
|
|
|
with gr.Row(): |
|
|
image_display = gr.Image( |
|
|
label="Identify this animal:", |
|
|
value=None, |
|
|
height=400, |
|
|
width=550, |
|
|
elem_id="main-image-display" |
|
|
) |
|
|
|
|
|
gr.Markdown("### What species is this?") |
|
|
|
|
|
with gr.Row(): |
|
|
|
|
|
species_buttons = [] |
|
|
for species_name in SPECIES_MAP.values(): |
|
|
btn = gr.Button(species_name, variant="primary", size="lg") |
|
|
species_buttons.append(btn) |
|
|
|
|
|
|
|
|
idk_button = gr.Button("I don't know", variant="primary", size="lg") |
|
|
|
|
|
|
|
|
with gr.Group(elem_classes="model-predictions-panel"): |
|
|
with gr.Column(elem_classes="flex items-center justify-center h-full"): |
|
|
model_predictions_display = gr.Markdown( |
|
|
"### Model Predictions\n\n*Start the demo to see model votes!*", |
|
|
show_label=False, |
|
|
elem_classes="text-center compact-predictions" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
with gr.Group(elem_classes="panel-container"): |
|
|
prob_help_button = gr.Button("?", elem_classes="help-icon", size="sm") |
|
|
prob_plot = gr.Plot( |
|
|
value=None, |
|
|
show_label=False |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
accuracy_plot = gr.Plot( |
|
|
value=create_accuracy_chart(), |
|
|
show_label=False, |
|
|
visible=False |
|
|
) |
|
|
with gr.Group(visible=True, elem_classes="subtle-outline accuracy-hidden-panel") as hidden_group: |
|
|
with gr.Column(elem_classes="flex items-center justify-center h-full"): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_text0 = gr.Markdown(""" |
|
|
## True model performance is hidden |
|
|
""", |
|
|
elem_classes="text-center",) |
|
|
|
|
|
gr.HTML("<div style='margin-top: 0.25em;'></div>") |
|
|
|
|
|
hidden_text1 = gr.Markdown(""" |
|
|
In this problem setting the true model performance is assumed to be unknown (that is why we want to perform model selection!) |
|
|
|
|
|
However, for this demo, we have computed the actual accuracies of each model in order to evaluate CODA's performance. |
|
|
|
|
|
""", |
|
|
elem_classes="text-center", |
|
|
) |
|
|
gr.HTML("<div style='margin-top: 0.25em;'></div>") |
|
|
with gr.Row(): |
|
|
reveal_accuracy_button = gr.Button( |
|
|
"🔍 Reveal True Model Accuracies", |
|
|
variant="secondary", |
|
|
size="lg" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if DEBUG_MODE: |
|
|
delete_button = gr.Button("🗑️ Delete Current Image", variant="stop", size="lg") |
|
|
|
|
|
|
|
|
def start_demo(): |
|
|
global iteration_count, coda_selector, dataset, oracle, image_filenames |
|
|
|
|
|
|
|
|
iteration_count = 0 |
|
|
|
|
|
|
|
|
        # Resample the class-balanced subset until CODA's initial (prior-based) best guess
        # differs from the true best model, so the demo starts where labeling still matters.
        while True:
|
|
|
|
|
subsampled_indices = [] |
|
|
for class_idx in sorted(full_class_to_indices.keys()): |
|
|
indices = full_class_to_indices[class_idx] |
|
|
sampled = np.random.choice(indices, size=min_class_size, replace=False) |
|
|
subsampled_indices.extend(sampled.tolist()) |
|
|
|
|
|
|
|
|
subsampled_indices.sort() |
|
|
|
|
|
|
|
|
subsampled_preds = full_preds[:, subsampled_indices, :] |
|
|
subsampled_labels = full_labels[subsampled_indices] |
|
|
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices] |
|
|
|
|
|
|
|
|
            # Construct a Dataset without calling __init__ and fill in only the fields
            # the demo needs (preds, labels, device) from the precomputed tensors.
            dataset = Dataset.__new__(Dataset)
|
|
dataset.preds = subsampled_preds |
|
|
dataset.labels = subsampled_labels |
|
|
dataset.device = device |
|
|
|
|
|
|
|
|
oracle = Oracle(dataset, loss_fn=loss_fn) |
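            # The Oracle has access to the ground-truth labels; it is used only for the
            # "true accuracy" panel revealed later in the UI.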
|
|
coda_selector = CODA(dataset, |
|
|
learning_rate=DEMO_LEARNING_RATE, |
|
|
alpha=DEMO_ALPHA) |
|
|
|
|
|
|
|
|
probs_tensor = coda_selector.get_pbest() |
|
|
probabilities = probs_tensor.detach().cpu().numpy().flatten() |
|
|
coda_best_idx = np.argmax(probabilities) |
|
|
|
|
|
|
|
|
true_losses = oracle.true_losses(dataset.preds) |
|
|
true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten() |
|
|
true_best_idx = np.argmax(true_accuracies) |
|
|
|
|
|
|
|
|
if coda_best_idx != true_best_idx: |
|
|
break |
|
|
|
|
|
|
|
|
image, status, predictions = get_next_coda_image() |
|
|
prob_plot = create_probability_chart() |
|
|
acc_plot = create_accuracy_chart() |
|
|
|
|
|
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>' |
|
|
return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True) |
|
|
|
|
|
def start_over(): |
|
|
global iteration_count, coda_selector, dataset, oracle, image_filenames |
|
|
|
|
|
|
|
|
iteration_count = 0 |
|
|
|
|
|
|
|
|
        # Same resampling loop as in start_demo(): keep drawing subsets until CODA's
        # initial best guess differs from the true best model.
        while True:
|
|
|
|
|
subsampled_indices = [] |
|
|
for class_idx in sorted(full_class_to_indices.keys()): |
|
|
indices = full_class_to_indices[class_idx] |
|
|
sampled = np.random.choice(indices, size=min_class_size, replace=False) |
|
|
subsampled_indices.extend(sampled.tolist()) |
|
|
|
|
|
|
|
|
subsampled_indices.sort() |
|
|
|
|
|
|
|
|
subsampled_preds = full_preds[:, subsampled_indices, :] |
|
|
subsampled_labels = full_labels[subsampled_indices] |
|
|
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices] |
|
|
|
|
|
|
|
|
            # As in start_demo(): build the Dataset by hand from the precomputed tensors.
            dataset = Dataset.__new__(Dataset)
|
|
dataset.preds = subsampled_preds |
|
|
dataset.labels = subsampled_labels |
|
|
dataset.device = device |
|
|
|
|
|
|
|
|
oracle = Oracle(dataset, loss_fn=loss_fn) |
|
|
coda_selector = CODA(dataset, |
|
|
learning_rate=DEMO_LEARNING_RATE, |
|
|
alpha=DEMO_ALPHA) |
|
|
|
|
|
|
|
|
probs_tensor = coda_selector.get_pbest() |
|
|
probabilities = probs_tensor.detach().cpu().numpy().flatten() |
|
|
coda_best_idx = np.argmax(probabilities) |
|
|
|
|
|
|
|
|
true_losses = oracle.true_losses(dataset.preds) |
|
|
true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten() |
|
|
true_best_idx = np.argmax(true_accuracies) |
|
|
|
|
|
|
|
|
if coda_best_idx != true_best_idx: |
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
prob_plot = create_probability_chart() |
|
|
acc_plot = create_accuracy_chart() |
|
|
return None, "Demo reset. Click 'Start CODA Demo' to begin.", "### Model Predictions\n\n*Start the demo to see model votes!*", prob_plot, acc_plot, "", gr.update(visible=True), gr.update(visible=False) |
|
|
|
|
|
def show_species_guide(): |
|
|
|
|
|
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False) |
|
|
|
|
|
def show_intro(): |
|
|
|
|
|
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True) |
|
|
|
|
|
def show_prob_help(): |
|
|
return gr.update(visible=True) |
|
|
|
|
|
def hide_prob_help(): |
|
|
return gr.update(visible=False) |
|
|
|
|
|
def show_acc_help(): |
|
|
return gr.update(visible=True) |
|
|
|
|
|
def hide_acc_help(): |
|
|
return gr.update(visible=False) |
|
|
|
|
|
def show_selection_help(): |
|
|
return gr.update(visible=True) |
|
|
|
|
|
def hide_selection_help(): |
|
|
return gr.update(visible=False) |
|
|
|
|
|
def show_species_guide_popup(): |
|
|
return gr.update(visible=True) |
|
|
|
|
|
def hide_species_guide_popup(): |
|
|
return gr.update(visible=False) |
|
|
|
|
|
def reveal_accuracies(): |
|
|
"""Reveal accuracy plot and hide the hidden group""" |
|
|
return gr.update(visible=True), gr.update(visible=False) |
|
|
|
|
|
popup_start_button.click( |
|
|
fn=start_demo, |
|
|
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, popup_overlay, result_display, selection_help_button], |
|
|
js=""" |
|
|
() => { |
|
|
console.log('=== Panel Height Matching (Dynamic) ==='); |
|
|
|
|
|
function matchPanelHeights() { |
|
|
const panels = document.querySelectorAll('.panel-container'); |
|
|
console.log('Found .panel-container elements:', panels.length); |
|
|
const leftPanel = panels[0]; // prob_plot panel |
|
|
const rightPanel = document.querySelector('.accuracy-hidden-panel'); // hidden_group panel |
|
|
|
|
|
console.log('Left panel (prob):', leftPanel); |
|
|
console.log('Right panel (hidden):', rightPanel); |
|
|
|
|
|
if (leftPanel && rightPanel) { |
|
|
const leftHeight = leftPanel.offsetHeight; |
|
|
const rightHeight = rightPanel.offsetHeight; |
|
|
const diff = leftHeight - rightHeight; |
|
|
|
|
|
console.log('Left panel height:', leftHeight); |
|
|
console.log('Right panel height:', rightHeight); |
|
|
console.log('Height difference:', diff); |
|
|
|
|
|
if (diff > 0) { |
|
|
console.log('Setting right panel min-height to:', leftHeight + 'px'); |
|
|
|
|
|
rightPanel.style.minHeight = leftHeight + 'px'; |
|
|
rightPanel.style.display = 'flex'; |
|
|
rightPanel.style.flexDirection = 'column'; |
|
|
rightPanel.style.justifyContent = 'center'; |
|
|
|
|
|
console.log('Applied min-height and flex centering'); |
|
|
return true; // Success |
|
|
} else { |
|
|
console.log('No height adjustment needed (diff <= 0)'); |
|
|
return true; // Success |
|
|
} |
|
|
} else { |
|
|
console.log('Panels not ready yet'); |
|
|
return false; // Not ready |
|
|
} |
|
|
} |
|
|
|
|
|
// Check every 50ms for 3 seconds to catch multiple height changes |
|
|
let attempts = 0; |
|
|
const maxAttempts = 60; // 60 * 50ms = 3 seconds to catch both height changes |
|
|
const checkInterval = setInterval(() => { |
|
|
attempts++; |
|
|
console.log('Attempt', attempts, 'to match heights'); |
|
|
|
|
|
matchPanelHeights(); // Always try, don't stop early |
|
|
|
|
|
if (attempts >= maxAttempts) { |
|
|
console.log('Finished checking after 3 seconds'); |
|
|
clearInterval(checkInterval); |
|
|
} |
|
|
}, 50); // Check every 50ms |
|
|
} |
|
|
""" |
|
|
) |
|
|
|
|
|
start_over_button.click( |
|
|
fn=start_over, |
|
|
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, result_display, popup_overlay, selection_help_button], |
|
|
js=""" |
|
|
() => { |
|
|
console.log('=== Panel Height Matching (Dynamic - Start Over) ==='); |
|
|
|
|
|
function matchPanelHeights() { |
|
|
const panels = document.querySelectorAll('.panel-container'); |
|
|
console.log('Found .panel-container elements:', panels.length); |
|
|
const leftPanel = panels[0]; // prob_plot panel |
|
|
const rightPanel = document.querySelector('.accuracy-hidden-panel'); // hidden_group panel |
|
|
|
|
|
console.log('Left panel (prob):', leftPanel); |
|
|
console.log('Right panel (hidden):', rightPanel); |
|
|
|
|
|
if (leftPanel && rightPanel) { |
|
|
const leftHeight = leftPanel.offsetHeight; |
|
|
const rightHeight = rightPanel.offsetHeight; |
|
|
const diff = leftHeight - rightHeight; |
|
|
|
|
|
console.log('Left panel height:', leftHeight); |
|
|
console.log('Right panel height:', rightHeight); |
|
|
console.log('Height difference:', diff); |
|
|
|
|
|
if (diff > 0) { |
|
|
console.log('Setting right panel min-height to:', leftHeight + 'px'); |
|
|
|
|
|
rightPanel.style.minHeight = leftHeight + 'px'; |
|
|
rightPanel.style.display = 'flex'; |
|
|
rightPanel.style.flexDirection = 'column'; |
|
|
rightPanel.style.justifyContent = 'center'; |
|
|
|
|
|
console.log('Applied min-height and flex centering'); |
|
|
return true; // Success |
|
|
} else { |
|
|
console.log('No height adjustment needed (diff <= 0)'); |
|
|
return true; // Success |
|
|
} |
|
|
} else { |
|
|
console.log('Panels not ready yet'); |
|
|
return false; // Not ready |
|
|
} |
|
|
} |
|
|
|
|
|
// Check every 50ms for 3 seconds to catch multiple height changes |
|
|
let attempts = 0; |
|
|
const maxAttempts = 60; // 60 * 50ms = 3 seconds to catch both height changes |
|
|
const checkInterval = setInterval(() => { |
|
|
attempts++; |
|
|
console.log('Attempt', attempts, 'to match heights'); |
|
|
|
|
|
matchPanelHeights(); // Always try, don't stop early |
|
|
|
|
|
if (attempts >= maxAttempts) { |
|
|
console.log('Finished checking after 3 seconds'); |
|
|
clearInterval(checkInterval); |
|
|
} |
|
|
}, 50); // Check every 50ms |
|
|
} |
|
|
""" |
|
|
) |
|
|
|
|
|
guide_button.click( |
|
|
fn=show_species_guide, |
|
|
outputs=[intro_content, species_guide_content, back_button, guide_button] |
|
|
) |
|
|
|
|
|
back_button.click( |
|
|
fn=show_intro, |
|
|
outputs=[intro_content, species_guide_content, back_button, guide_button] |
|
|
) |
|
|
|
|
|
|
|
|
prob_help_button.click( |
|
|
fn=show_prob_help, |
|
|
outputs=[prob_help_popup] |
|
|
) |
|
|
|
|
|
prob_help_close.click( |
|
|
fn=hide_prob_help, |
|
|
outputs=[prob_help_popup] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
acc_help_close.click( |
|
|
fn=hide_acc_help, |
|
|
outputs=[acc_help_popup] |
|
|
) |
|
|
|
|
|
selection_help_button.click( |
|
|
fn=show_selection_help, |
|
|
outputs=[selection_help_popup] |
|
|
) |
|
|
|
|
|
selection_help_close.click( |
|
|
fn=hide_selection_help, |
|
|
outputs=[selection_help_popup] |
|
|
) |
|
|
|
|
|
|
|
|
reveal_accuracy_button.click( |
|
|
fn=reveal_accuracies, |
|
|
outputs=[accuracy_plot, hidden_group] |
|
|
) |
|
|
|
|
|
|
|
|
view_guide_button.click( |
|
|
fn=show_species_guide_popup, |
|
|
outputs=[species_guide_popup] |
|
|
) |
|
|
|
|
|
species_guide_close.click( |
|
|
fn=hide_species_guide_popup, |
|
|
outputs=[species_guide_popup] |
|
|
) |
|
|
|
|
|
for btn in species_buttons: |
|
|
btn.click( |
|
|
fn=check_answer, |
|
|
inputs=[gr.State(btn.value)], |
|
|
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot] |
|
|
) |
|
|
|
|
|
idk_button.click( |
|
|
fn=check_answer, |
|
|
inputs=[gr.State("I don't know")], |
|
|
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot] |
|
|
) |
|
|
|
|
|
|
|
|
if DEBUG_MODE: |
|
|
delete_button.click( |
|
|
fn=delete_current_image, |
|
|
outputs=[result_display, image_display, model_predictions_display, prob_plot, accuracy_plot, status_with_help] |
|
|
) |
|
|
|
|
|
|
|
|
demo.load( |
|
|
lambda: None, |
|
|
outputs=[], |
|
|
js=""" |
|
|
() => { |
|
|
// Handle inline help button clicks |
|
|
setTimeout(() => { |
|
|
document.addEventListener('click', function(e) { |
|
|
if (e.target && e.target.classList.contains('inline-help-btn')) { |
|
|
e.preventDefault(); |
|
|
e.stopPropagation(); |
|
|
const hiddenBtn = document.getElementById('hidden-selection-help-btn'); |
|
|
if (hiddenBtn) { |
|
|
hiddenBtn.click(); |
|
|
} |
|
|
} |
|
|
}); |
|
|
}, 100); |
|
|
|
|
|
// Dynamic image sizing (NEW VERSION) |
|
|
console.log('=== IMAGE SIZING V2 LOADED ==='); |
|
|
|
|
|
function adjustImageSize() { |
|
|
const imageContainer = document.getElementById('main-image-display'); |
|
|
if (!imageContainer) { |
|
|
console.log('[V2] Image container not found'); |
|
|
return false; |
|
|
} |
|
|
|
|
|
const viewportHeight = window.innerHeight; |
|
|
const docHeight = document.documentElement.scrollHeight; |
|
|
const currentImageHeight = imageContainer.offsetHeight; |
|
|
|
|
|
// Calculate how much we're overflowing |
|
|
const overflow = docHeight - viewportHeight; |
|
|
|
|
|
// If we're not overflowing, increase image size |
|
|
// If we are overflowing, decrease image size by the overflow amount |
|
|
const adjustment = -overflow - 30; // Keep padding below bottom button |
|
|
const targetHeight = currentImageHeight + adjustment; |
|
|
|
|
|
console.log('[V2] viewport:', viewportHeight, 'docHeight:', docHeight, 'currentImg:', currentImageHeight, 'overflow:', overflow, 'target:', targetHeight); |
|
|
|
|
|
// Only apply if reasonable |
|
|
if (targetHeight > 300 && targetHeight < viewportHeight - 100) { |
|
|
imageContainer.style.height = targetHeight + 'px'; |
|
|
imageContainer.style.maxHeight = targetHeight + 'px'; |
|
|
console.log('[V2] Set image height to:', targetHeight + 'px'); |
|
|
return true; |
|
|
} |
|
|
|
|
|
return false; |
|
|
} |
|
|
|
|
|
// Run after initial load |
|
|
setTimeout(adjustImageSize, 500); |
|
|
|
|
|
// Run periodically for first 5 seconds to catch layout changes |
|
|
let attempts = 0; |
|
|
const interval = setInterval(() => { |
|
|
attempts++; |
|
|
adjustImageSize(); |
|
|
if (attempts >= 50) { // 50 * 100ms = 5 seconds |
|
|
clearInterval(interval); |
|
|
} |
|
|
}, 100); |
|
|
|
|
|
// Re-adjust on window resize |
|
|
window.addEventListener('resize', adjustImageSize); |
|
|
} |
|
|
""", |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch( |
|
|
|
|
|
|
|
|
allowed_paths=["/"], |
|
|
) |
|
|
|