import hashlib
import json
import pickle
import re
from datetime import datetime
from pathlib import Path
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from tqdm import tqdm
# Cache configuration
CACHE_DIR = Path("./pwc_cache")
CACHE_DIR.mkdir(exist_ok=True)
# Directory structure for disk-based storage
TASKS_INDEX_FILE = CACHE_DIR / "tasks_index.json" # Small JSON file with task list
TASK_DATA_DIR = CACHE_DIR / "task_data" # Directory for individual task files
DATASET_DATA_DIR = CACHE_DIR / "dataset_data" # Directory for individual dataset files
METRICS_INDEX_FILE = CACHE_DIR / "metrics_index.json" # Metrics metadata
# Create directories
TASK_DATA_DIR.mkdir(exist_ok=True)
DATASET_DATA_DIR.mkdir(exist_ok=True)
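# Resulting on-disk layout (illustrative filenames; actual names depend on the data):
#   pwc_cache/
#     tasks_index.json                          # JSON list of task names
#     metrics_index.json                        # per-metric counts and direction flags
#     task_data/task_<task>.pkl                 # categories, datasets, date range
#     dataset_data/data_<task>_<dataset>.pkl    # one DataFrame per (task, dataset)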
def sanitize_filename(name):
    """Convert a string to a safe filename."""
    # Replace problematic characters (separators, wildcards, spaces, dots) with underscores
    safe_name = re.sub(r'[\\/:*?"<>| .]', '_', name)
    # Collapse runs of underscores and trim leading/trailing ones
    safe_name = '_'.join(filter(None, safe_name.split('_')))
    # Limit length to avoid filesystem issues
    if len(safe_name) > 200:
        # If too long, keep the first 150 chars plus a short hash of the full name
        safe_name = safe_name[:150] + '_' + hashlib.md5(name.encode()).hexdigest()[:8]
    return safe_name
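# Illustrative examples (hypothetical inputs, not taken from the dataset):
#   sanitize_filename("Image Classification")  -> "Image_Classification"
#   sanitize_filename("a/b:c?d")               -> "a_b_c_d"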
def get_task_filename(task):
"""Generate a safe filename for a task."""
safe_name = sanitize_filename(task)
return TASK_DATA_DIR / f"task_{safe_name}.pkl"
def get_dataset_filename(task, dataset_name):
"""Generate a safe filename for a dataset."""
safe_task = sanitize_filename(task)
safe_dataset = sanitize_filename(dataset_name)
# Include both task and dataset in filename for clarity
filename = f"data_{safe_task}_{safe_dataset}.pkl"
# If combined name is too long, shorten it
if len(filename) > 255:
# Use shorter version with hash
filename = f"data_{safe_task[:50]}_{safe_dataset[:50]}_{hashlib.md5(f'{task}||{dataset_name}'.encode()).hexdigest()[:8]}.pkl"
return DATASET_DATA_DIR / filename
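# Illustrative result (hypothetical names):
#   get_dataset_filename("Image Classification", "ImageNet")
#   -> pwc_cache/dataset_data/data_Image_Classification_ImageNet.pkl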
def cache_exists():
    """Check whether the cache structure exists."""
    return TASKS_INDEX_FILE.exists() and METRICS_INDEX_FILE.exists()
def build_disk_based_cache():
    """Build the cache with minimal memory usage by flushing each task and
    dataset to disk as it is processed."""
    print("=" * 60)
    print("Building disk-based cache (one-time operation)...")
    print("=" * 60)
# Initialize tracking structures (kept small)
tasks_set = set()
metrics_index = {}
print("\n[1/4] Streaming dataset and building cache...")
# Load dataset in streaming mode to save memory
ds = load_dataset("pwc-archive/evaluation-tables", split="train", streaming=False)
total_items = len(ds)
processed_count = 0
dataset_count = 0
    for item in tqdm(ds, total=total_items):
task = item['task']
if not task:
continue
tasks_set.add(task)
# Load existing task data from disk or create new
task_file = get_task_filename(task)
if task_file.exists():
with open(task_file, 'rb') as f:
task_data = pickle.load(f)
else:
task_data = {
'categories': set(),
'datasets': set(),
'date_range': {'min': None, 'max': None}
}
# Update task data
if item['categories']:
task_data['categories'].update(item['categories'])
# Process datasets
if item['datasets']:
for dataset in item['datasets']:
if not isinstance(dataset, dict) or 'dataset' not in dataset:
continue
dataset_name = dataset['dataset']
dataset_file = get_dataset_filename(task, dataset_name)
# Skip if already processed
if dataset_file.exists():
task_data['datasets'].add(dataset_name)
continue
task_data['datasets'].add(dataset_name)
# Process SOTA data
if 'sota' not in dataset or 'rows' not in dataset['sota']:
continue
models_data = []
for row in dataset['sota']['rows']:
if not isinstance(row, dict):
continue
model_name = row.get('model_name', 'Unknown Model')
# Extract metrics
metrics = {}
if 'metrics' in row and isinstance(row['metrics'], dict):
for metric_name, metric_value in row['metrics'].items():
if metric_value is not None:
metrics[metric_name] = metric_value
# Track metric metadata
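                                # Heuristic: metric names containing error/loss/time/cost count as lower-is-better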
if metric_name not in metrics_index:
metrics_index[metric_name] = {
'count': 0,
'is_lower_better': any(kw in metric_name.lower()
for kw in ['error', 'loss', 'time', 'cost'])
}
metrics_index[metric_name]['count'] += 1
# Parse date
paper_date = row.get('paper_date')
try:
if paper_date and isinstance(paper_date, str):
release_date = pd.to_datetime(paper_date)
else:
release_date = pd.to_datetime('2020-01-01')
                    except Exception:
                        release_date = pd.to_datetime('2020-01-01')
# Update date range
if task_data['date_range']['min'] is None or release_date < task_data['date_range']['min']:
task_data['date_range']['min'] = release_date
if task_data['date_range']['max'] is None or release_date > task_data['date_range']['max']:
task_data['date_range']['max'] = release_date
# Build model entry
model_entry = {
'model_name': model_name,
'release_date': release_date,
'paper_date': row.get('paper_date', ''), # Store raw paper_date for dynamic parsing
'paper_url': row.get('paper_url', ''),
'paper_title': row.get('paper_title', ''),
'code_url': row.get('code_links', [''])[0] if row.get('code_links') else '',
**metrics
}
models_data.append(model_entry)
if models_data:
df = pd.DataFrame(models_data)
df = df.sort_values('release_date')
# Save dataset to its own file
with open(dataset_file, 'wb') as f:
pickle.dump(df, f, protocol=pickle.HIGHEST_PROTOCOL)
dataset_count += 1
# Clear DataFrame from memory
del df
del models_data
# Save updated task data back to disk
with open(task_file, 'wb') as f:
# Convert sets to lists for serialization
task_data_to_save = {
                'categories': sorted(task_data['categories']),
                'datasets': sorted(task_data['datasets']),
'date_range': task_data['date_range']
}
pickle.dump(task_data_to_save, f, protocol=pickle.HIGHEST_PROTOCOL)
# Clear task data from memory
del task_data
processed_count += 1
print(f"\nβœ“ Processed {len(tasks_set)} tasks and {dataset_count} datasets")
print("\n[2/4] Saving index files...")
# Save tasks index (small file)
    tasks_list = sorted(tasks_set)
with open(TASKS_INDEX_FILE, 'w') as f:
json.dump(tasks_list, f)
print(f" βœ“ Saved tasks index ({len(tasks_list)} tasks)")
# Save metrics index
with open(METRICS_INDEX_FILE, 'w') as f:
json.dump(metrics_index, f, indent=2)
print(f" βœ“ Saved metrics index ({len(metrics_index)} metrics)")
print("\n[3/4] Calculating cache statistics...")
# Calculate total cache size
total_size = 0
for file in TASK_DATA_DIR.glob("*.pkl"):
total_size += file.stat().st_size
for file in DATASET_DATA_DIR.glob("*.pkl"):
total_size += file.stat().st_size
print(f" βœ“ Total cache size: {total_size / 1024 / 1024:.1f} MB")
print(f" βœ“ Task files: {len(list(TASK_DATA_DIR.glob('*.pkl')))}")
print(f" βœ“ Dataset files: {len(list(DATASET_DATA_DIR.glob('*.pkl')))}")
print("\n[4/4] Cache building complete!")
print("=" * 60)
return tasks_list
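# To force a rebuild (e.g. after a schema change), delete tasks_index.json and
# metrics_index.json: cache_exists() then returns False and build_disk_based_cache()
# runs again (existing per-dataset files are reused, see the skip logic above).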
def load_tasks_index():
"""Load just the task list from disk."""
with open(TASKS_INDEX_FILE, 'r') as f:
return json.load(f)
def load_task_data(task):
"""Load data for a specific task from disk."""
task_file = get_task_filename(task)
if task_file.exists():
with open(task_file, 'rb') as f:
return pickle.load(f)
return None
def load_dataset_data(task, dataset_name):
"""Load a specific dataset from disk."""
dataset_file = get_dataset_filename(task, dataset_name)
if dataset_file.exists():
with open(dataset_file, 'rb') as f:
return pickle.load(f)
return pd.DataFrame()
def load_metrics_index():
"""Load metrics index from disk."""
if METRICS_INDEX_FILE.exists():
with open(METRICS_INDEX_FILE, 'r') as f:
return json.load(f)
return {}
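# Shape of the metrics index (illustrative entry, not actual data):
#   {"Top-1 Accuracy": {"count": 312, "is_lower_better": false}, ...}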
# Initialize: build the cache if it doesn't exist yet
if cache_exists():
print("Loading task index from disk...")
TASKS = load_tasks_index()
print(f"βœ“ Loaded {len(TASKS)} tasks")
else:
TASKS = build_disk_based_cache()
# Load metrics index once (it's small)
METRICS_INDEX = load_metrics_index()
# Memory-efficient accessor functions
def get_tasks():
"""Get all tasks from index."""
return TASKS
def get_task_data(task):
"""Load task data from disk on-demand."""
return load_task_data(task)
def get_categories(task):
"""Get categories for a task (loads from disk)."""
task_data = get_task_data(task)
return task_data['categories'] if task_data else []
def get_datasets_for_task(task):
"""Get datasets for a task (loads from disk)."""
task_data = get_task_data(task)
return task_data['datasets'] if task_data else []
def get_cached_model_data(task, dataset_name):
"""Load dataset from disk on-demand."""
return load_dataset_data(task, dataset_name)
def parse_paper_date(paper_date, paper_title="", paper_url=""):
    """Parse a paper date with layered fallback strategies."""
# Try to parse the raw paper_date if available
if paper_date and isinstance(paper_date, str) and paper_date.strip():
try:
# Try common date formats
date_formats = [
'%Y-%m-%d',
'%Y/%m/%d',
'%d-%m-%Y',
'%d/%m/%Y',
'%Y-%m',
'%Y/%m',
'%Y'
]
for fmt in date_formats:
try:
return pd.to_datetime(paper_date.strip(), format=fmt)
                except (ValueError, TypeError):
continue
# Try pandas automatic parsing
return pd.to_datetime(paper_date.strip())
        except Exception:
pass
# Fallback: try to extract year from paper title or URL
year_pattern = r'\b(19[5-9]\d|20[0-9]\d)\b' # Match 1950-2099
# Look for year in paper title
if paper_title:
years = re.findall(year_pattern, str(paper_title))
if years:
try:
year = max(years) # Use the latest year found
return pd.to_datetime(f'{year}-01-01')
            except Exception:
pass
# Look for year in paper URL
if paper_url:
years = re.findall(year_pattern, str(paper_url))
if years:
try:
year = max(years) # Use the latest year found
return pd.to_datetime(f'{year}-01-01')
            except Exception:
pass
# Final fallback: return None instead of a default year
return None
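# Illustrative behaviour (hypothetical inputs):
#   parse_paper_date('2021-06-01')                          -> Timestamp('2021-06-01')
#   parse_paper_date('', paper_title='A 2019 Survey of X')  -> Timestamp('2019-01-01')
#   parse_paper_date('n/a')                                 -> None (no year recoverable)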
def get_task_statistics(task):
    """Get statistics about a task (placeholder: not implemented yet, returns an empty dict)."""
    return {}
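# Local helper: all "no data" branches in create_sota_plot below return this
# placeholder figure instead of repeating the same Plotly boilerplate inline.
def make_empty_figure(message, title, font_size=20):
    """Return a placeholder figure with a centered message and a plain white layout."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=font_size)
    )
    fig.update_layout(
        title=title,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig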
def create_sota_plot(df, metric):
"""Create a plot showing model performance evolution over time.
Args:
df: DataFrame with model data
metric: Metric name to plot on y-axis
"""
    if df.empty or metric not in df.columns:
        return make_empty_figure("No data available for this metric", "No Data Available")
# Remove rows where the metric is NaN
df_clean = df.dropna(subset=[metric]).copy()
    if df_clean.empty:
        return make_empty_figure("No valid data points for this metric", "No Data Available")
    # Convert metric column to numeric, stripping a trailing "%" from string values
try:
df_clean[metric] = pd.to_numeric(
df_clean[metric].apply(lambda x: x.strip()[:-1] if isinstance(x, str) and x.strip().endswith("%") else x),
errors='coerce')
# Remove any rows that couldn't be converted to numeric
df_clean = df_clean.dropna(subset=[metric])
        if df_clean.empty:
            return make_empty_figure(f"No numeric data available for metric: {metric}",
                                     "No Numeric Data Available")
    except Exception as e:
        return make_empty_figure(f"Error processing metric data: {e}",
                                 "Data Processing Error", font_size=16)
# Recalculate release dates dynamically from raw paper_date if available
df_processed = df_clean.copy()
if 'paper_date' in df_processed.columns:
# Parse dates dynamically using improved logic
df_processed['dynamic_release_date'] = df_processed.apply(
lambda row: parse_paper_date(
row.get('paper_date', ''),
row.get('paper_title', ''),
row.get('paper_url', '')
), axis=1
)
# Use dynamic dates if available, otherwise fallback to original release_date
df_processed['final_release_date'] = df_processed['dynamic_release_date'].fillna(df_processed['release_date'])
else:
# If no paper_date column, use existing release_date
df_processed['final_release_date'] = df_processed['release_date']
# Filter out rows with no valid date
df_with_dates = df_processed[df_processed['final_release_date'].notna()].copy()
    if df_with_dates.empty:
        # No valid dates available: return a placeholder plot
        return make_empty_figure("No valid dates available for this dataset", "No Date Data Available")
# Sort by final release date
df_sorted = df_with_dates.sort_values('final_release_date').copy()
    # Determine direction: prefer the metrics index, else fall back to a keyword heuristic
    if metric in METRICS_INDEX:
        is_lower_better = METRICS_INDEX[metric].get('is_lower_better', False)
    else:
        is_lower_better = any(keyword in metric.lower() for keyword in ['error', 'loss', 'time', 'cost'])
    if is_lower_better:
        df_sorted['cumulative_best'] = df_sorted[metric].cummin()
    else:
        df_sorted['cumulative_best'] = df_sorted[metric].cummax()
    df_sorted['is_sota'] = df_sorted[metric] == df_sorted['cumulative_best']
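    # Cumulative-best example for a higher-is-better metric (illustrative values):
    #   metric:          0.85  0.83  0.90  0.88
    #   cumulative_best: 0.85  0.85  0.90  0.90
    #   is_sota:         True  False True  False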
# Get SOTA models
sota_df = df_sorted[df_sorted['is_sota']].copy()
# Use the dynamically calculated dates for x-axis
x_values = df_sorted['final_release_date']
x_axis_title = 'Release Date'
# Create the plot
fig = go.Figure()
# Add all models as scatter points
fig.add_trace(go.Scatter(
x=x_values,
y=df_sorted[metric],
mode='markers',
name='All models',
marker=dict(
color=['#00CED1' if is_sota else 'lightgray'
for is_sota in df_sorted['is_sota']],
size=8,
opacity=0.7
),
text=df_sorted['model_name'],
customdata=df_sorted[['paper_title', 'paper_url', 'code_url']],
hovertemplate='<b>%{text}</b><br>' +
f'{metric}: %{{y:.4f}}<br>' +
'Date: %{x}<br>' +
'Paper: %{customdata[0]}<br>' +
'<extra></extra>'
))
# Add SOTA line
fig.add_trace(go.Scatter(
x=x_values,
y=df_sorted['cumulative_best'],
mode='lines',
name=f'SOTA (cumulative {"min" if is_lower_better else "max"})',
line=dict(color='#00CED1', width=2, dash='solid'),
hovertemplate=f'SOTA {metric}: %{{y:.4f}}<br>{x_axis_title}: %{{x}}<extra></extra>'
))
# Add labels for SOTA models
if not sota_df.empty:
# Calculate dynamic offset based on data range
y_range = df_sorted[metric].max() - df_sorted[metric].min()
# Use a percentage of the range for offset, with minimum and maximum bounds
if y_range > 0:
base_offset = y_range * 0.03 # 3% of the data range
# Ensure minimum offset for readability and maximum to prevent excessive spacing
label_offset = max(y_range * 0.01, min(base_offset, y_range * 0.08))
else:
# Fallback for when all values are the same
label_offset = 1
# Track label positions to prevent overlaps
previous_labels = []
# For date-based x-axis, use date separation
try:
date_range = (df_sorted['final_release_date'].max() - df_sorted['final_release_date'].min()).days
min_separation = max(30, date_range * 0.05) # Minimum 30 days or 5% of range
except (TypeError, AttributeError):
# Fallback if date calculation fails
min_separation = 30
for i, (_, row) in enumerate(sota_df.iterrows()):
# Determine base label position based on metric type
if is_lower_better:
# For lower-better metrics, place label above the point (negative ay)
base_ay_offset = -label_offset
base_yshift = -8
alternate_multiplier = -1
else:
# For higher-better metrics, place label below the point (positive ay)
base_ay_offset = label_offset
base_yshift = 8
alternate_multiplier = 1
# Check for collision with previous labels
current_x = row['final_release_date']
collision_detected = False
        for prev_x, _ in previous_labels:
try:
x_diff = abs((current_x - prev_x).days)
if x_diff < min_separation:
collision_detected = True
break
except (TypeError, AttributeError):
# Skip collision detection if calculation fails
continue
# Adjust position if collision detected
if collision_detected:
# Alternate the label position (above/below) to avoid overlap
ay_offset = base_ay_offset + (alternate_multiplier * label_offset * 0.7 * (i % 2))
yshift = base_yshift + (alternate_multiplier * 12 * (i % 2))
else:
ay_offset = base_ay_offset
yshift = base_yshift
# Add the annotation
fig.add_annotation(
x=current_x,
y=row[metric],
text=row['model_name'][:25] + '...' if len(row['model_name']) > 25 else row['model_name'],
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=1,
arrowcolor='#00CED1', # Match the SOTA line color
ax=0,
ay=ay_offset, # Dynamic offset based on data range and collision detection
yshift=yshift, # Fine-tune positioning
font=dict(size=8, color='#333333'),
bgcolor='rgba(255, 255, 255, 0.9)', # Semi-transparent background
borderwidth=0 # Remove border
)
# Track this label position
previous_labels.append((current_x, ay_offset))
# Update layout
fig.update_layout(
title=f'SOTA Evolution: {metric}',
xaxis_title=x_axis_title,
yaxis_title=metric,
xaxis=dict(showgrid=True, gridcolor='lightgray'),
yaxis=dict(showgrid=True, gridcolor='lightgray'),
plot_bgcolor='white',
paper_bgcolor='white',
height=600,
legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
hovermode='closest'
)
    # Release intermediate DataFrames before returning the figure
    del df_clean, df_processed, df_with_dates, df_sorted, sota_df
return fig
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ“Š Papers with Code - SOTA Evolution Visualizer")
gr.Markdown(
"Navigate through ML tasks and datasets to visualize the evolution of state-of-the-art models over time.")
gr.Markdown("*Optimized for low memory usage - data is loaded on-demand from disk*")
# Status
with gr.Row():
gr.Markdown(f"""
<div style="background-color: #f0f9ff; border-left: 4px solid #00CED1; padding: 10px; margin: 10px 0;">
<b>πŸ’Ύ Disk-Based Storage Active</b><br>
β€’ <b>{len(TASKS)}</b> tasks indexed<br>
β€’ <b>{len(METRICS_INDEX)}</b> unique metrics tracked<br>
β€’ Data loaded on-demand to minimize RAM usage
</div>
""")
# State variables
current_df = gr.State(pd.DataFrame())
current_task = gr.State(None)
# Navigation dropdowns
with gr.Row():
task_dropdown = gr.Dropdown(
choices=get_tasks(),
label="Select Task",
interactive=True
)
category_dropdown = gr.Dropdown(
choices=[],
label="Categories (info only)",
interactive=False
)
with gr.Row():
dataset_dropdown = gr.Dropdown(
choices=[],
label="Select Dataset",
interactive=True
)
metric_dropdown = gr.Dropdown(
choices=[],
label="Select Metric",
interactive=True
)
# Info display
info_text = gr.Markdown("πŸ‘† Please select a task to begin")
# Plot
plot = gr.Plot(label="SOTA Evolution")
# Data display
with gr.Row():
show_data_btn = gr.Button("πŸ“‹ Show/Hide Model Data")
export_btn = gr.Button("πŸ’Ύ Export Current Data (CSV)")
clear_memory_btn = gr.Button("🧹 Clear Memory", variant="secondary")
df_display = gr.Dataframe(
label="Model Data",
visible=False
)
# Update functions
def update_task_selection(task):
"""Update dropdowns when task is selected."""
        if not task:
            return (
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                "👆 Please select a task to begin",
                pd.DataFrame(),
                None,
                None,
            )
# Load task data from disk
categories = get_categories(task)
datasets = get_datasets_for_task(task)
info = f"### πŸ“‚ **Task:** {task}\n"
if categories:
info += f"- **Categories:** {', '.join(categories[:3])}{'...' if len(categories) > 3 else ''} ({len(categories)} total)\n"
return (
gr.Dropdown(choices=categories, value=categories[0] if categories else None),
gr.Dropdown(choices=datasets, value=None),
gr.Dropdown(choices=[], value=None),
info,
pd.DataFrame(),
None,
task # Store current task
)
def update_dataset_selection(task, dataset_name):
"""Update when dataset is selected - loads from disk."""
        if not task or not dataset_name:
            return gr.Dropdown(choices=[], value=None), "", pd.DataFrame(), None
# Load dataset from disk
df = get_cached_model_data(task, dataset_name)
        if df.empty:
            return gr.Dropdown(choices=[], value=None), f"⚠️ No models found for dataset: {dataset_name}", df, None
# Get metric columns
exclude_cols = ['model_name', 'release_date', 'paper_date', 'paper_url', 'paper_title', 'code_url']
metric_cols = [col for col in df.columns if col not in exclude_cols]
info = f"### πŸ“Š **Dataset:** {dataset_name}\n"
info += f"- **Models:** {len(df)} models\n"
info += f"- **Metrics:** {len(metric_cols)} metrics available\n"
if not df.empty:
info += f"- **Date Range:** {df['release_date'].min().strftime('%Y-%m-%d')} to {df['release_date'].max().strftime('%Y-%m-%d')}\n"
if metric_cols:
info += f"- **Available Metrics:** {', '.join(metric_cols[:5])}{'...' if len(metric_cols) > 5 else ''}"
return (
gr.Dropdown(choices=metric_cols, value=metric_cols[0] if metric_cols else None),
info,
df,
None
)
def update_plot(df, metric):
"""Update plot when metric is selected."""
if df.empty or not metric:
return None
plot_result = create_sota_plot(df, metric)
return plot_result
    def toggle_dataframe(df):
        """Show the model data table (it stays hidden only when there is no data)."""
if df.empty:
return gr.Dataframe(value=pd.DataFrame(), visible=False)
# Show relevant columns
display_cols = ['model_name', 'release_date'] + [col for col in df.columns
if col not in ['model_name', 'release_date', 'paper_date',
'paper_url',
'paper_title', 'code_url']]
display_df = df[display_cols].copy()
display_df['release_date'] = display_df['release_date'].dt.strftime('%Y-%m-%d')
return gr.Dataframe(value=display_df, visible=True)
def export_data(df):
"""Export current dataframe to CSV."""
if df.empty:
return "⚠️ No data to export"
filename = f"sota_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
df.to_csv(filename, index=False)
return f"βœ… Data exported to {filename} ({len(df)} models)"
def clear_memory():
"""Clear memory by forcing garbage collection."""
import gc
gc.collect()
return "βœ… Memory cleared"
# Event handlers
task_dropdown.change(
fn=update_task_selection,
inputs=task_dropdown,
outputs=[category_dropdown, dataset_dropdown,
metric_dropdown, info_text, current_df, plot, current_task]
)
dataset_dropdown.change(
fn=update_dataset_selection,
inputs=[task_dropdown, dataset_dropdown],
outputs=[metric_dropdown, info_text, current_df, plot]
)
metric_dropdown.change(
fn=update_plot,
inputs=[current_df, metric_dropdown],
outputs=plot
)
show_data_btn.click(
fn=toggle_dataframe,
inputs=current_df,
outputs=df_display
)
export_btn.click(
fn=export_data,
inputs=current_df,
outputs=info_text
)
clear_memory_btn.click(
fn=clear_memory,
inputs=[],
outputs=info_text
)
gr.Markdown("""
---
### πŸ“– How to Use
1. **Select a Task** from the first dropdown
2. **Select a Dataset** to analyze
3. **Select a Metric** to visualize
4. The plot shows SOTA model evolution over time with dynamically calculated dates
### πŸ’Ύ Memory Optimization
- Data is stored on disk and loaded on-demand
- Only the current task and dataset are kept in memory
- Use "Clear Memory" button if needed
- Infinite disk space is utilized for permanent caching
### 🎨 Plot Features
- **πŸ”΅ Cyan dots**: SOTA models when released
- **βšͺ Gray dots**: Other models
- **πŸ“ˆ Cyan line**: SOTA progression
- **πŸ” Hover**: View model details
- **🏷️ Smart Labels**: SOTA model labels positioned close to the line with intelligent collision detection
""")
def test_sota_label_positioning():
"""Test function to validate SOTA label positioning improvements."""
print("πŸ§ͺ Testing SOTA label positioning...")
    # Create sample data for testing (module-level pandas/datetime imports are reused)
# Test data with different metric types (including all required columns)
test_data = {
'model_name': ['Model A', 'Model B', 'Model C', 'Model D'],
'release_date': [
datetime(2020, 1, 1),
datetime(2020, 6, 1),
datetime(2021, 1, 1),
datetime(2021, 6, 1)
],
'paper_title': ['Paper A', 'Paper B', 'Paper C', 'Paper D'],
'paper_url': ['http://example.com/a', 'http://example.com/b', 'http://example.com/c', 'http://example.com/d'],
'code_url': ['http://github.com/a', 'http://github.com/b', 'http://github.com/c', 'http://github.com/d'],
'accuracy': [0.85, 0.87, 0.90, 0.92], # Higher-better metric
'error_rate': [0.15, 0.13, 0.10, 0.08] # Lower-better metric
}
df_test = pd.DataFrame(test_data)
# Test with higher-better metric (accuracy)
print(" Testing with higher-better metric (accuracy)...")
try:
fig1 = create_sota_plot(df_test, 'accuracy')
print(" βœ… Higher-better metric test passed")
except Exception as e:
print(f" ❌ Higher-better metric test failed: {e}")
# Test with lower-better metric (error_rate)
print(" Testing with lower-better metric (error_rate)...")
try:
fig2 = create_sota_plot(df_test, 'error_rate')
print(" βœ… Lower-better metric test passed")
except Exception as e:
print(f" ❌ Lower-better metric test failed: {e}")
# Test with empty data
print(" Testing with empty dataframe...")
try:
fig3 = create_sota_plot(pd.DataFrame(), 'test_metric')
print(" βœ… Empty data test passed")
except Exception as e:
print(f" ❌ Empty data test failed: {e}")
# Test with string metric data (should handle gracefully)
print(" Testing with string metric data...")
try:
df_test_string = df_test.copy()
df_test_string['string_metric'] = ['low', 'medium', 'high', 'very_high']
fig4 = create_sota_plot(df_test_string, 'string_metric')
print(" βœ… String metric test passed (handled gracefully)")
except Exception as e:
print(f" ❌ String metric test failed: {e}")
# Test with mixed numeric/string data
print(" Testing with mixed data types...")
try:
df_test_mixed = df_test.copy()
df_test_mixed['mixed_metric'] = [0.85, 'N/A', 0.90, 0.92]
fig5 = create_sota_plot(df_test_mixed, 'mixed_metric')
print(" βœ… Mixed data test passed")
except Exception as e:
print(f" ❌ Mixed data test failed: {e}")
# Test with paper_date parsing
print(" Testing with paper_date column...")
try:
df_test_dates = df_test.copy()
df_test_dates['paper_date'] = ['2015-03-15', '2018-invalid', '2021-12-01', '2022']
fig6 = create_sota_plot(df_test_dates, 'accuracy')
print(" βœ… Paper date parsing test passed")
except Exception as e:
print(f" ❌ Paper date parsing test failed: {e}")
print("πŸŽ‰ SOTA label positioning tests completed!")
return True
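# The self-test above is not wired into the UI; run it manually if needed:
#   test_sota_label_positioning()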
demo.launch()