import hashlib
import json
import pickle
import re
from datetime import datetime
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from tqdm import tqdm
# Cache configuration
CACHE_DIR = Path("./pwc_cache")
CACHE_DIR.mkdir(exist_ok=True)

# Directory structure for disk-based storage
TASKS_INDEX_FILE = CACHE_DIR / "tasks_index.json"  # Small JSON file with task list
TASK_DATA_DIR = CACHE_DIR / "task_data"  # Directory for individual task files
DATASET_DATA_DIR = CACHE_DIR / "dataset_data"  # Directory for individual dataset files
METRICS_INDEX_FILE = CACHE_DIR / "metrics_index.json"  # Metrics metadata

# Create directories
TASK_DATA_DIR.mkdir(exist_ok=True)
DATASET_DATA_DIR.mkdir(exist_ok=True)
def sanitize_filename(name):
    """Convert a string to a safe filename."""
    # Replace problematic characters with underscores
    safe_name = name.replace('/', '_').replace('\\', '_').replace(':', '_')
    safe_name = safe_name.replace('*', '_').replace('?', '_').replace('"', '_')
    safe_name = safe_name.replace('<', '_').replace('>', '_').replace('|', '_')
    safe_name = safe_name.replace(' ', '_').replace('.', '_')
    # Collapse repeated underscores and trim
    safe_name = '_'.join(filter(None, safe_name.split('_')))
    # Limit length to avoid filesystem issues
    if len(safe_name) > 200:
        # If too long, use the first 150 chars + a hash of the full name
        safe_name = safe_name[:150] + '_' + hashlib.md5(name.encode()).hexdigest()[:8]
    return safe_name
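
# For example, sanitize_filename("Semantic Segmentation: ADE20K") returns
# "Semantic_Segmentation_ADE20K"; over-long names keep their first 150
# characters plus an 8-character MD5 digest so distinct names stay distinct.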
def get_task_filename(task):
    """Generate a safe filename for a task."""
    safe_name = sanitize_filename(task)
    return TASK_DATA_DIR / f"task_{safe_name}.pkl"


def get_dataset_filename(task, dataset_name):
    """Generate a safe filename for a dataset."""
    safe_task = sanitize_filename(task)
    safe_dataset = sanitize_filename(dataset_name)
    # Include both task and dataset in the filename for clarity
    filename = f"data_{safe_task}_{safe_dataset}.pkl"
    # If the combined name is too long, shorten it and add a hash
    if len(filename) > 255:
        digest = hashlib.md5(f'{task}||{dataset_name}'.encode()).hexdigest()[:8]
        filename = f"data_{safe_task[:50]}_{safe_dataset[:50]}_{digest}.pkl"
    return DATASET_DATA_DIR / filename
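
# For example, get_dataset_filename("Image Classification", "CIFAR-10") maps to
# pwc_cache/dataset_data/data_Image_Classification_CIFAR-10.pkl.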
def cache_exists():
    """Check whether the cache index files exist on disk."""
    return TASKS_INDEX_FILE.exists() and METRICS_INDEX_FILE.exists()
def build_disk_based_cache():
    """Build the cache with minimal memory usage, flushing each task and dataset to disk as it is processed."""
    print("=" * 60)
    print("Building disk-based cache (one-time operation)...")
    print("=" * 60)
    # Initialize tracking structures (kept small)
    tasks_set = set()
    metrics_index = {}
    print("\n[1/4] Iterating over the dataset and building cache...")
    # Load the full dataset (non-streaming) so that len() works; memory stays
    # low because each task/dataset is written to disk as soon as it is built.
    ds = load_dataset("pwc-archive/evaluation-tables", split="train", streaming=False)
    total_items = len(ds)
    processed_count = 0
    dataset_count = 0
    for item in tqdm(ds, total=total_items):
        task = item['task']
        if not task:
            continue
        tasks_set.add(task)
        # Load existing task data from disk or create new
        task_file = get_task_filename(task)
        if task_file.exists():
            with open(task_file, 'rb') as f:
                task_data = pickle.load(f)
            # Saved files store lists; convert back to sets so updates work
            task_data['categories'] = set(task_data['categories'])
            task_data['datasets'] = set(task_data['datasets'])
        else:
            task_data = {
                'categories': set(),
                'datasets': set(),
                'date_range': {'min': None, 'max': None}
            }
        # Update task data
        if item['categories']:
            task_data['categories'].update(item['categories'])
        # Process datasets
        if item['datasets']:
            for dataset in item['datasets']:
                if not isinstance(dataset, dict) or 'dataset' not in dataset:
                    continue
                dataset_name = dataset['dataset']
                dataset_file = get_dataset_filename(task, dataset_name)
                # Skip if already processed
                if dataset_file.exists():
                    task_data['datasets'].add(dataset_name)
                    continue
                task_data['datasets'].add(dataset_name)
                # Process SOTA data
                if 'sota' not in dataset or 'rows' not in dataset['sota']:
                    continue
                models_data = []
                for row in dataset['sota']['rows']:
                    if not isinstance(row, dict):
                        continue
                    model_name = row.get('model_name', 'Unknown Model')
                    # Extract metrics
                    metrics = {}
                    if 'metrics' in row and isinstance(row['metrics'], dict):
                        for metric_name, metric_value in row['metrics'].items():
                            if metric_value is not None:
                                metrics[metric_name] = metric_value
                                # Track metric metadata
                                if metric_name not in metrics_index:
                                    metrics_index[metric_name] = {
                                        'count': 0,
                                        'is_lower_better': any(kw in metric_name.lower()
                                                               for kw in ['error', 'loss', 'time', 'cost'])
                                    }
                                metrics_index[metric_name]['count'] += 1
                    # Parse the date, falling back to a fixed default
                    paper_date = row.get('paper_date')
                    try:
                        if paper_date and isinstance(paper_date, str):
                            release_date = pd.to_datetime(paper_date)
                        else:
                            release_date = pd.to_datetime('2020-01-01')
                    except (ValueError, TypeError):
                        release_date = pd.to_datetime('2020-01-01')
                    # Update the date range
                    if task_data['date_range']['min'] is None or release_date < task_data['date_range']['min']:
                        task_data['date_range']['min'] = release_date
                    if task_data['date_range']['max'] is None or release_date > task_data['date_range']['max']:
                        task_data['date_range']['max'] = release_date
                    # Build the model entry
                    model_entry = {
                        'model_name': model_name,
                        'release_date': release_date,
                        'paper_date': row.get('paper_date', ''),  # Raw date kept for dynamic parsing
                        'paper_url': row.get('paper_url', ''),
                        'paper_title': row.get('paper_title', ''),
                        'code_url': row.get('code_links', [''])[0] if row.get('code_links') else '',
                        **metrics
                    }
                    models_data.append(model_entry)
                if models_data:
                    df = pd.DataFrame(models_data)
                    df = df.sort_values('release_date')
                    # Save the dataset to its own file
                    with open(dataset_file, 'wb') as f:
                        pickle.dump(df, f, protocol=pickle.HIGHEST_PROTOCOL)
                    dataset_count += 1
                    # Clear the DataFrame from memory
                    del df
                    del models_data
        # Save the updated task data back to disk
        with open(task_file, 'wb') as f:
            # Convert sets to lists for serialization
            task_data_to_save = {
                'categories': sorted(task_data['categories']),
                'datasets': sorted(task_data['datasets']),
                'date_range': task_data['date_range']
            }
            pickle.dump(task_data_to_save, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Clear the task data from memory
        del task_data
        processed_count += 1
    print(f"\n✓ Processed {len(tasks_set)} tasks and {dataset_count} datasets")
print("\n[2/4] Saving index files...") | |
# Save tasks index (small file) | |
tasks_list = sorted(list(tasks_set)) | |
with open(TASKS_INDEX_FILE, 'w') as f: | |
json.dump(tasks_list, f) | |
print(f" β Saved tasks index ({len(tasks_list)} tasks)") | |
# Save metrics index | |
with open(METRICS_INDEX_FILE, 'w') as f: | |
json.dump(metrics_index, f, indent=2) | |
print(f" β Saved metrics index ({len(metrics_index)} metrics)") | |
print("\n[3/4] Calculating cache statistics...") | |
# Calculate total cache size | |
total_size = 0 | |
for file in TASK_DATA_DIR.glob("*.pkl"): | |
total_size += file.stat().st_size | |
for file in DATASET_DATA_DIR.glob("*.pkl"): | |
total_size += file.stat().st_size | |
print(f" β Total cache size: {total_size / 1024 / 1024:.1f} MB") | |
print(f" β Task files: {len(list(TASK_DATA_DIR.glob('*.pkl')))}") | |
print(f" β Dataset files: {len(list(DATASET_DATA_DIR.glob('*.pkl')))}") | |
print("\n[4/4] Cache building complete!") | |
print("=" * 60) | |
return tasks_list | |
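
# The resulting on-disk layout looks like this (file names per the helpers above):
#   pwc_cache/
#     tasks_index.json                        # sorted list of task names
#     metrics_index.json                      # {metric: {"count": int, "is_lower_better": bool}}
#     task_data/task_<task>.pkl               # categories, datasets, date range per task
#     dataset_data/data_<task>_<dataset>.pkl  # one DataFrame of models per dataset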
def load_tasks_index():
    """Load just the task list from disk."""
    with open(TASKS_INDEX_FILE, 'r') as f:
        return json.load(f)


def load_task_data(task):
    """Load data for a specific task from disk."""
    task_file = get_task_filename(task)
    if task_file.exists():
        with open(task_file, 'rb') as f:
            return pickle.load(f)
    return None


def load_dataset_data(task, dataset_name):
    """Load a specific dataset from disk."""
    dataset_file = get_dataset_filename(task, dataset_name)
    if dataset_file.exists():
        with open(dataset_file, 'rb') as f:
            return pickle.load(f)
    return pd.DataFrame()


def load_metrics_index():
    """Load the metrics index from disk."""
    if METRICS_INDEX_FILE.exists():
        with open(METRICS_INDEX_FILE, 'r') as f:
            return json.load(f)
    return {}
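
# A metrics index entry has the shape built in build_disk_based_cache(), e.g.:
#   {"Top 1 Accuracy": {"count": 412, "is_lower_better": false}}
# (the metric name and count shown here are illustrative).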
# Initialize - build the cache if it doesn't exist yet
if cache_exists():
    print("Loading task index from disk...")
    TASKS = load_tasks_index()
    print(f"✓ Loaded {len(TASKS)} tasks")
else:
    TASKS = build_disk_based_cache()

# Load the metrics index once (it's small)
METRICS_INDEX = load_metrics_index()


# Memory-efficient accessor functions
def get_tasks():
    """Get all tasks from the index."""
    return TASKS


def get_task_data(task):
    """Load task data from disk on demand."""
    return load_task_data(task)


def get_categories(task):
    """Get categories for a task (loads from disk)."""
    task_data = get_task_data(task)
    return task_data['categories'] if task_data else []


def get_datasets_for_task(task):
    """Get datasets for a task (loads from disk)."""
    task_data = get_task_data(task)
    return task_data['datasets'] if task_data else []


def get_cached_model_data(task, dataset_name):
    """Load a dataset from disk on demand."""
    return load_dataset_data(task, dataset_name)
def parse_paper_date(paper_date, paper_title="", paper_url=""):
    """Parse a paper date with layered fallback strategies."""
    # Try to parse the raw paper_date if available
    if paper_date and isinstance(paper_date, str) and paper_date.strip():
        # Try common explicit date formats first
        date_formats = [
            '%Y-%m-%d',
            '%Y/%m/%d',
            '%d-%m-%Y',
            '%d/%m/%Y',
            '%Y-%m',
            '%Y/%m',
            '%Y'
        ]
        for fmt in date_formats:
            try:
                return pd.to_datetime(paper_date.strip(), format=fmt)
            except (ValueError, TypeError):
                continue
        # Fall back to pandas' automatic parsing
        try:
            return pd.to_datetime(paper_date.strip())
        except (ValueError, TypeError):
            pass
    # Fallback: try to extract a year from the paper title or URL
    year_pattern = r'\b(19[5-9]\d|20[0-9]\d)\b'  # Match 1950-2099
    # Look for a year in the paper title
    if paper_title:
        years = re.findall(year_pattern, str(paper_title))
        if years:
            year = max(years)  # Use the latest year found
            return pd.to_datetime(f'{year}-01-01')
    # Look for a year in the paper URL
    if paper_url:
        years = re.findall(year_pattern, str(paper_url))
        if years:
            year = max(years)  # Use the latest year found
            return pd.to_datetime(f'{year}-01-01')
    # Final fallback: return None instead of a default year
    return None
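
# Examples of the fallback chain:
#   parse_paper_date("2021-12-01")                  -> Timestamp("2021-12-01")
#   parse_paper_date("", paper_title="BERT (2018)") -> Timestamp("2018-01-01")
#   parse_paper_date("not-a-date", "", "")          -> None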
def get_task_statistics(task):
    """Get statistics about a task (placeholder; not yet implemented)."""
    return {}
def make_empty_figure(message, title="No Data Available", font_size=20):
    """Return a blank figure with a centered message."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=font_size)
    )
    fig.update_layout(
        title=title,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig


def create_sota_plot(df, metric):
    """Create a plot showing model performance evolution over time.

    Args:
        df: DataFrame with model data
        metric: Metric name to plot on the y-axis
    """
    if df.empty or metric not in df.columns:
        return make_empty_figure("No data available for this metric")
    # Remove rows where the metric is NaN
    df_clean = df.dropna(subset=[metric]).copy()
    if df_clean.empty:
        return make_empty_figure("No valid data points for this metric")
    # Convert the metric column to numeric, handling strings such as "93.2%"
    try:
        df_clean[metric] = pd.to_numeric(
            df_clean[metric].apply(lambda x: x.strip()[:-1] if isinstance(x, str) and x.strip().endswith("%") else x),
            errors='coerce')
        # Remove any rows that couldn't be converted to numeric
        df_clean = df_clean.dropna(subset=[metric])
        if df_clean.empty:
            return make_empty_figure(f"No numeric data available for metric: {metric}",
                                     title="No Numeric Data Available")
    except Exception as e:
        return make_empty_figure(f"Error processing metric data: {str(e)}",
                                 title="Data Processing Error", font_size=16)
    # Recalculate release dates dynamically from the raw paper_date if available
    df_processed = df_clean.copy()
    if 'paper_date' in df_processed.columns:
        # Parse dates dynamically using the improved fallback logic
        df_processed['dynamic_release_date'] = df_processed.apply(
            lambda row: parse_paper_date(
                row.get('paper_date', ''),
                row.get('paper_title', ''),
                row.get('paper_url', '')
            ), axis=1
        )
        # Use dynamic dates if available, otherwise fall back to the original release_date
        df_processed['final_release_date'] = df_processed['dynamic_release_date'].fillna(df_processed['release_date'])
    else:
        # If there is no paper_date column, use the existing release_date
        df_processed['final_release_date'] = df_processed['release_date']
    # Filter out rows with no valid date
    df_with_dates = df_processed[df_processed['final_release_date'].notna()].copy()
    if df_with_dates.empty:
        return make_empty_figure("No valid dates available for this dataset",
                                 title="No Date Data Available")
    # Sort by final release date
    df_sorted = df_with_dates.sort_values('final_release_date').copy()
    # Check whether the metric is lower-better
    if metric in METRICS_INDEX:
        is_lower_better = METRICS_INDEX[metric].get('is_lower_better', False)
    else:
        is_lower_better = any(keyword in metric.lower() for keyword in ['error', 'loss', 'time', 'cost'])
    # A model is SOTA if it matches the running best seen so far
    if is_lower_better:
        df_sorted['cumulative_best'] = df_sorted[metric].cummin()
    else:
        df_sorted['cumulative_best'] = df_sorted[metric].cummax()
    df_sorted['is_sota'] = df_sorted[metric] == df_sorted['cumulative_best']
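    # For instance, accuracies [0.85, 0.84, 0.90] in date order give
    # cumulative_best [0.85, 0.85, 0.90] and is_sota [True, False, True].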
    # Get the SOTA models
    sota_df = df_sorted[df_sorted['is_sota']].copy()
    # Use the dynamically calculated dates for the x-axis
    x_values = df_sorted['final_release_date']
    x_axis_title = 'Release Date'
    # Create the plot
    fig = go.Figure()
    # Add all models as scatter points
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted[metric],
        mode='markers',
        name='All models',
        marker=dict(
            color=['#00CED1' if is_sota else 'lightgray'
                   for is_sota in df_sorted['is_sota']],
            size=8,
            opacity=0.7
        ),
        text=df_sorted['model_name'],
        customdata=df_sorted[['paper_title', 'paper_url', 'code_url']],
        hovertemplate='<b>%{text}</b><br>' +
                      f'{metric}: %{{y:.4f}}<br>' +
                      'Date: %{x}<br>' +
                      'Paper: %{customdata[0]}<br>' +
                      '<extra></extra>'
    ))
    # Add the SOTA line
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted['cumulative_best'],
        mode='lines',
        name=f'SOTA (cumulative {"min" if is_lower_better else "max"})',
        line=dict(color='#00CED1', width=2, dash='solid'),
        hovertemplate=f'SOTA {metric}: %{{y:.4f}}<br>{x_axis_title}: %{{x}}<extra></extra>'
    ))
    # Add labels for the SOTA models
    if not sota_df.empty:
        # Calculate a dynamic offset based on the data range
        y_range = df_sorted[metric].max() - df_sorted[metric].min()
        # Use a percentage of the range for the offset, bounded on both sides
        if y_range > 0:
            base_offset = y_range * 0.03  # 3% of the data range
            # Keep a minimum offset for readability and a maximum to avoid excessive spacing
            label_offset = max(y_range * 0.01, min(base_offset, y_range * 0.08))
        else:
            # Fallback for when all values are the same
            label_offset = 1
        # Track label positions to prevent overlaps
        previous_labels = []
        # For the date-based x-axis, use date separation
        try:
            date_range = (df_sorted['final_release_date'].max() - df_sorted['final_release_date'].min()).days
            min_separation = max(30, date_range * 0.05)  # At least 30 days or 5% of the range
        except (TypeError, AttributeError):
            # Fallback if the date calculation fails
            min_separation = 30
        for i, (_, row) in enumerate(sota_df.iterrows()):
            # Determine the base label position based on the metric type
            if is_lower_better:
                # For lower-better metrics, place the label above the point (negative ay)
                base_ay_offset = -label_offset
                base_yshift = -8
                alternate_multiplier = -1
            else:
                # For higher-better metrics, place the label below the point (positive ay)
                base_ay_offset = label_offset
                base_yshift = 8
                alternate_multiplier = 1
            # Check for collisions with previous labels
            current_x = row['final_release_date']
            collision_detected = False
            for prev_x, prev_ay in previous_labels:
                try:
                    x_diff = abs((current_x - prev_x).days)
                    if x_diff < min_separation:
                        collision_detected = True
                        break
                except (TypeError, AttributeError):
                    # Skip collision detection if the calculation fails
                    continue
            # Adjust the position if a collision was detected
            if collision_detected:
                # Alternate the label position (above/below) to avoid overlap
                ay_offset = base_ay_offset + (alternate_multiplier * label_offset * 0.7 * (i % 2))
                yshift = base_yshift + (alternate_multiplier * 12 * (i % 2))
            else:
                ay_offset = base_ay_offset
                yshift = base_yshift
            # Add the annotation
            fig.add_annotation(
                x=current_x,
                y=row[metric],
                text=row['model_name'][:25] + '...' if len(row['model_name']) > 25 else row['model_name'],
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=1,
                arrowcolor='#00CED1',  # Match the SOTA line color
                ax=0,
                ay=ay_offset,  # Dynamic offset based on data range and collision detection
                yshift=yshift,  # Fine-tune positioning
                font=dict(size=8, color='#333333'),
                bgcolor='rgba(255, 255, 255, 0.9)',  # Semi-transparent background
                borderwidth=0  # No border
            )
            # Track this label's position
            previous_labels.append((current_x, ay_offset))
    # Update the layout
    fig.update_layout(
        title=f'SOTA Evolution: {metric}',
        xaxis_title=x_axis_title,
        yaxis_title=metric,
        xaxis=dict(showgrid=True, gridcolor='lightgray'),
        yaxis=dict(showgrid=True, gridcolor='lightgray'),
        plot_bgcolor='white',
        paper_bgcolor='white',
        height=600,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        hovermode='closest'
    )
    # Drop intermediate DataFrames once the figure is built
    del df_clean
    del df_sorted
    del sota_df
    return fig
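
# A typical call, assuming the task and dataset names below exist in the local cache:
#   df = get_cached_model_data("Image Classification", "ImageNet")
#   fig = create_sota_plot(df, "Top 1 Accuracy")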
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Papers with Code - SOTA Evolution Visualizer")
    gr.Markdown(
        "Navigate through ML tasks and datasets to visualize the evolution of state-of-the-art models over time.")
    gr.Markdown("*Optimized for low memory usage - data is loaded on demand from disk*")
    # Status
    with gr.Row():
        gr.Markdown(f"""
        <div style="background-color: #f0f9ff; border-left: 4px solid #00CED1; padding: 10px; margin: 10px 0;">
        <b>💾 Disk-Based Storage Active</b><br>
        • <b>{len(TASKS)}</b> tasks indexed<br>
        • <b>{len(METRICS_INDEX)}</b> unique metrics tracked<br>
        • Data loaded on demand to minimize RAM usage
        </div>
        """)
    # State variables
    current_df = gr.State(pd.DataFrame())
    current_task = gr.State(None)
    # Navigation dropdowns
    with gr.Row():
        task_dropdown = gr.Dropdown(
            choices=get_tasks(),
            label="Select Task",
            interactive=True
        )
        category_dropdown = gr.Dropdown(
            choices=[],
            label="Categories (info only)",
            interactive=False
        )
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=[],
            label="Select Dataset",
            interactive=True
        )
        metric_dropdown = gr.Dropdown(
            choices=[],
            label="Select Metric",
            interactive=True
        )
    # Info display
    info_text = gr.Markdown("📊 Please select a task to begin")
    # Plot
    plot = gr.Plot(label="SOTA Evolution")
    # Data display
    with gr.Row():
        show_data_btn = gr.Button("📋 Show/Hide Model Data")
        export_btn = gr.Button("💾 Export Current Data (CSV)")
        clear_memory_btn = gr.Button("🧹 Clear Memory", variant="secondary")
    df_display = gr.Dataframe(
        label="Model Data",
        visible=False
    )
    # Update functions
    def update_task_selection(task):
        """Update the dropdowns when a task is selected."""
        if not task:
            return [], [], [], "📊 Please select a task to begin", pd.DataFrame(), None, None
        # Load task data from disk
        categories = get_categories(task)
        datasets = get_datasets_for_task(task)
        info = f"### 📌 **Task:** {task}\n"
        if categories:
            info += f"- **Categories:** {', '.join(categories[:3])}{'...' if len(categories) > 3 else ''} ({len(categories)} total)\n"
        return (
            gr.Dropdown(choices=categories, value=categories[0] if categories else None),
            gr.Dropdown(choices=datasets, value=None),
            gr.Dropdown(choices=[], value=None),
            info,
            pd.DataFrame(),
            None,
            task  # Store the current task
        )
    def update_dataset_selection(task, dataset_name):
        """Update when a dataset is selected - loads from disk."""
        if not task or not dataset_name:
            return [], "", pd.DataFrame(), None
        # Load the dataset from disk
        df = get_cached_model_data(task, dataset_name)
        if df.empty:
            return [], f"⚠️ No models found for dataset: {dataset_name}", df, None
        # Get the metric columns
        exclude_cols = ['model_name', 'release_date', 'paper_date', 'paper_url', 'paper_title', 'code_url']
        metric_cols = [col for col in df.columns if col not in exclude_cols]
        info = f"### 📊 **Dataset:** {dataset_name}\n"
        info += f"- **Models:** {len(df)} models\n"
        info += f"- **Metrics:** {len(metric_cols)} metrics available\n"
        info += f"- **Date Range:** {df['release_date'].min().strftime('%Y-%m-%d')} to {df['release_date'].max().strftime('%Y-%m-%d')}\n"
        if metric_cols:
            info += f"- **Available Metrics:** {', '.join(metric_cols[:5])}{'...' if len(metric_cols) > 5 else ''}"
        return (
            gr.Dropdown(choices=metric_cols, value=metric_cols[0] if metric_cols else None),
            info,
            df,
            None
        )
    def update_plot(df, metric):
        """Update the plot when a metric is selected."""
        if df.empty or not metric:
            return None
        return create_sota_plot(df, metric)

    def toggle_dataframe(df):
        """Toggle the dataframe's visibility."""
        if df.empty:
            return gr.Dataframe(value=pd.DataFrame(), visible=False)
        # Show the relevant columns
        display_cols = ['model_name', 'release_date'] + [col for col in df.columns
                                                         if col not in ['model_name', 'release_date', 'paper_date',
                                                                        'paper_url', 'paper_title', 'code_url']]
        display_df = df[display_cols].copy()
        display_df['release_date'] = display_df['release_date'].dt.strftime('%Y-%m-%d')
        return gr.Dataframe(value=display_df, visible=True)

    def export_data(df):
        """Export the current dataframe to CSV."""
        if df.empty:
            return "⚠️ No data to export"
        filename = f"sota_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filename, index=False)
        return f"✓ Data exported to {filename} ({len(df)} models)"

    def clear_memory():
        """Free memory by forcing garbage collection."""
        import gc
        gc.collect()
        return "✓ Memory cleared"
    # Event handlers
    task_dropdown.change(
        fn=update_task_selection,
        inputs=task_dropdown,
        outputs=[category_dropdown, dataset_dropdown,
                 metric_dropdown, info_text, current_df, plot, current_task]
    )
    dataset_dropdown.change(
        fn=update_dataset_selection,
        inputs=[task_dropdown, dataset_dropdown],
        outputs=[metric_dropdown, info_text, current_df, plot]
    )
    metric_dropdown.change(
        fn=update_plot,
        inputs=[current_df, metric_dropdown],
        outputs=plot
    )
    show_data_btn.click(
        fn=toggle_dataframe,
        inputs=current_df,
        outputs=df_display
    )
    export_btn.click(
        fn=export_data,
        inputs=current_df,
        outputs=info_text
    )
    clear_memory_btn.click(
        fn=clear_memory,
        inputs=[],
        outputs=info_text
    )
gr.Markdown(""" | |
--- | |
### π How to Use | |
1. **Select a Task** from the first dropdown | |
2. **Select a Dataset** to analyze | |
3. **Select a Metric** to visualize | |
4. The plot shows SOTA model evolution over time with dynamically calculated dates | |
### πΎ Memory Optimization | |
- Data is stored on disk and loaded on-demand | |
- Only the current task and dataset are kept in memory | |
- Use "Clear Memory" button if needed | |
- Infinite disk space is utilized for permanent caching | |
### π¨ Plot Features | |
- **π΅ Cyan dots**: SOTA models when released | |
- **βͺ Gray dots**: Other models | |
- **π Cyan line**: SOTA progression | |
- **π Hover**: View model details | |
- **π·οΈ Smart Labels**: SOTA model labels positioned close to the line with intelligent collision detection | |
""") | |
def test_sota_label_positioning():
    """Smoke-test the SOTA label positioning improvements."""
    print("🧪 Testing SOTA label positioning...")
    # Sample data covering different metric types (with all required columns)
    test_data = {
        'model_name': ['Model A', 'Model B', 'Model C', 'Model D'],
        'release_date': [
            datetime(2020, 1, 1),
            datetime(2020, 6, 1),
            datetime(2021, 1, 1),
            datetime(2021, 6, 1)
        ],
        'paper_title': ['Paper A', 'Paper B', 'Paper C', 'Paper D'],
        'paper_url': ['http://example.com/a', 'http://example.com/b', 'http://example.com/c', 'http://example.com/d'],
        'code_url': ['http://github.com/a', 'http://github.com/b', 'http://github.com/c', 'http://github.com/d'],
        'accuracy': [0.85, 0.87, 0.90, 0.92],  # Higher-better metric
        'error_rate': [0.15, 0.13, 0.10, 0.08]  # Lower-better metric
    }
    df_test = pd.DataFrame(test_data)
    # Test with a higher-better metric (accuracy)
    print("  Testing with higher-better metric (accuracy)...")
    try:
        fig1 = create_sota_plot(df_test, 'accuracy')
        print("  ✅ Higher-better metric test passed")
    except Exception as e:
        print(f"  ❌ Higher-better metric test failed: {e}")
    # Test with a lower-better metric (error_rate)
    print("  Testing with lower-better metric (error_rate)...")
    try:
        fig2 = create_sota_plot(df_test, 'error_rate')
        print("  ✅ Lower-better metric test passed")
    except Exception as e:
        print(f"  ❌ Lower-better metric test failed: {e}")
    # Test with empty data
    print("  Testing with empty dataframe...")
    try:
        fig3 = create_sota_plot(pd.DataFrame(), 'test_metric')
        print("  ✅ Empty data test passed")
    except Exception as e:
        print(f"  ❌ Empty data test failed: {e}")
    # Test with string metric data (should be handled gracefully)
    print("  Testing with string metric data...")
    try:
        df_test_string = df_test.copy()
        df_test_string['string_metric'] = ['low', 'medium', 'high', 'very_high']
        fig4 = create_sota_plot(df_test_string, 'string_metric')
        print("  ✅ String metric test passed (handled gracefully)")
    except Exception as e:
        print(f"  ❌ String metric test failed: {e}")
    # Test with mixed numeric/string data
    print("  Testing with mixed data types...")
    try:
        df_test_mixed = df_test.copy()
        df_test_mixed['mixed_metric'] = [0.85, 'N/A', 0.90, 0.92]
        fig5 = create_sota_plot(df_test_mixed, 'mixed_metric')
        print("  ✅ Mixed data test passed")
    except Exception as e:
        print(f"  ❌ Mixed data test failed: {e}")
    # Test paper_date parsing
    print("  Testing with paper_date column...")
    try:
        df_test_dates = df_test.copy()
        df_test_dates['paper_date'] = ['2015-03-15', '2018-invalid', '2021-12-01', '2022']
        fig6 = create_sota_plot(df_test_dates, 'accuracy')
        print("  ✅ Paper date parsing test passed")
    except Exception as e:
        print(f"  ❌ Paper date parsing test failed: {e}")
    print("🎉 SOTA label positioning tests completed!")
    return True
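
# The self-checks above are not run automatically; to exercise them before
# serving the app, call test_sota_label_positioning() here.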
demo.launch()