import json
import gradio as gr
import pandas as pd
import plotly.express as px
import os
import numpy as np
import io
import duckdb
# Hub pipeline tags offered in the "Select Pipeline Tag" dropdown, in the
# order they are displayed there. A selected value is matched exactly
# against each model's `pipeline_tag`.
PIPELINE_TAGS = [
    "text-generation",
    "text-to-image",
    "text-classification",
    "text2text-generation",
    "audio-to-audio",
    "feature-extraction",
    "image-classification",
    "translation",
    "reinforcement-learning",
    "fill-mask",
    "text-to-speech",
    "automatic-speech-recognition",
    "image-text-to-text",
    "token-classification",
    "sentence-similarity",
    "question-answering",
    "image-feature-extraction",
    "summarization",
    "zero-shot-image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "audio-classification",
    "visual-question-answering",
    "text-to-video",
    "zero-shot-classification",
    "depth-estimation",
    "text-ranking",
    "image-to-video",
    "multiple-choice",
    "unconditional-image-generation",
    "video-classification",
    "text-to-audio",
    "time-series-forecasting",
    "any-to-any",
    "video-text-to-text",
    "table-question-answering",
]
# Model size buckets: label -> (lower bound, upper bound) in GB. The labels
# are the choices of the "Model Size Filter" dropdown; the bounds are read
# as half-open intervals [lower, upper).
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float("inf")),  # open-ended top bucket
}
# Row predicates for the tag categories. Each one just reads a boolean column
# that load_models_data() precomputes, so no tag scanning happens here.
def is_audio_speech(row):
    """Return the precomputed ``is_audio_speech`` flag of ``row``."""
    return row["is_audio_speech"]


def is_music(row):
    """Return the precomputed ``has_music`` flag of ``row``."""
    return row["has_music"]


def is_robotics(row):
    """Return the precomputed ``has_robot`` flag of ``row``."""
    return row["has_robot"]


def is_biomed(row):
    """Return the precomputed ``is_biomed`` flag of ``row``."""
    return row["is_biomed"]


def is_timeseries(row):
    """Return the precomputed ``has_series`` flag of ``row``."""
    return row["has_series"]


def is_science(row):
    """Return the precomputed ``has_science`` flag of ``row``."""
    return row["has_science"]


def is_video(row):
    """Return the precomputed ``has_video`` flag of ``row``."""
    return row["has_video"]


def is_image(row):
    """Return the precomputed ``has_image`` flag of ``row``."""
    return row["has_image"]


def is_text(row):
    """Return the precomputed ``has_text`` flag of ``row``."""
    return row["has_text"]
def _row_tags_contain(row, needle):
    """Return True if any tag of ``row`` contains ``needle`` (case-insensitive).

    ``row`` is anything with a dict-style ``.get`` (a dict or a pandas row).
    Its ``"tags"`` entry may be a numpy array, a list/tuple, or a single
    string. A missing or None entry, or any other type, yields False.
    ``needle`` must already be lower-case.
    """
    tags = row.get("tags", [])
    if tags is None:
        return False
    # Unwrap numpy arrays (and 0-d numpy values) into plain Python objects so
    # a 0-d string array is checked as a string, not character by character.
    if hasattr(tags, "dtype") and hasattr(tags, "tolist"):
        tags = tags.tolist()
    if isinstance(tags, (list, tuple)):
        return any(needle in str(tag).lower() for tag in tags)
    if isinstance(tags, str):
        return needle in tags.lower()
    return False


def is_image(row):
    """Return True if ``row`` falls in the "Images" tag category.

    These definitions replace the earlier cached-column ones of the same name.
    To stay consistent with them, the row's boolean ``has_image`` flag is used
    when present. Otherwise the row's tags are scanned for "image".
    """
    cached = row.get("has_image")
    if isinstance(cached, (bool, np.bool_)):
        return bool(cached)
    return _row_tags_contain(row, "image")


def is_text(row):
    """Return True if ``row`` falls in the "Text" tag category.

    Uses the row's boolean ``has_text`` flag when present; otherwise scans the
    row's tags for "text".
    """
    cached = row.get("has_text")
    if isinstance(cached, (bool, np.bool_)):
        return bool(cached)
    return _row_tags_contain(row, "text")
def extract_model_size(safetensors_data):
    """Extract model size in GB from safetensors data.

    ``safetensors_data`` may be a dict, a JSON string encoding a dict, or a
    missing value (None/NaN). Only the ``'total'`` entry is used; it is
    divided by 1024**3.

    NOTE(review): on the Hugging Face Hub, ``safetensors.total`` is documented
    as the total *parameter count*, not a byte size, so this value is closer
    to "billions of parameters" than gigabytes. Confirm before trusting the
    GB labels of the size buckets.

    Returns:
        The size as a float, or 0 when the input is missing, not a dict after
        decoding, unparsable JSON, lacks ``'total'``, or has a non-numeric
        ``'total'``.
    """
    # Decode JSON text first so both accepted input forms share one path.
    if isinstance(safetensors_data, str):
        try:
            safetensors_data = json.loads(safetensors_data)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return 0
    # None, NaN, lists, arrays, ... carry no size. Checking the type here (not
    # pd.isna) avoids the ambiguous-truth error pd.isna causes on array-likes.
    if not isinstance(safetensors_data, dict) or 'total' not in safetensors_data:
        return 0
    try:
        return float(safetensors_data['total']) / (1024 * 1024 * 1024)
    except (ValueError, TypeError, OverflowError):
        return 0
# Size-range predicate based on the precomputed size_category column.
def is_in_size_range(row, size_range):
    """Return True if ``row`` matches the selected size bucket.

    A ``size_range`` of None or the literal string "None" means "no size
    filter", and every row matches (the row is not read in that case).
    Otherwise the row's ``size_category`` must equal ``size_range``.
    """
    no_filter = size_range is None or size_range == "None"
    return True if no_filter else row['size_category'] == size_range
# Tag-filter label -> row predicate. The labels (in this order) populate the
# "Select Tag" dropdown, and make_treemap_data only honours a tag filter whose
# label is a key here.
TAG_FILTER_FUNCS = dict(
    [
        ("Audio & Speech", is_audio_speech),
        ("Time series", is_timeseries),
        ("Robotics", is_robotics),
        ("Music", is_music),
        ("Video", is_video),
        ("Images", is_image),
        ("Text", is_text),
        ("Biomedical", is_biomed),
        ("Sciences", is_science),
    ]
)
def extract_org_from_id(model_id):
    """Return the organization part of a model ID such as "org/model".

    The organization is everything before the first "/". IDs containing no
    "/" are reported as "unaffiliated".
    """
    if "/" not in model_id:
        return "unaffiliated"
    owner, _remainder = model_id.split("/", 1)
    return owner
# Cached boolean column (built in load_models_data) backing each tag-filter label.
_TAG_FILTER_COLUMNS = {
    "Audio & Speech": "is_audio_speech",
    "Time series": "has_series",
    "Robotics": "has_robot",
    "Music": "has_music",
    "Video": "has_video",
    "Images": "has_image",
    "Text": "has_text",
    "Biomedical": "is_biomed",
    "Sciences": "has_science",
}


def _seconds_since(start_time):
    """Seconds elapsed since ``start_time`` (a pandas Timestamp)."""
    return (pd.Timestamp.now() - start_time).total_seconds()


def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
    """Filter ``df`` and shape it into rows for a root -> organization -> model treemap.

    Filters are applied in this order: tag category, pipeline tag, size
    category, then organization exclusion. After filtering, only models of the
    ``top_k`` organizations with the largest summed ``count_by`` are kept.

    Args:
        df: Models frame with ``id``, ``pipeline_tag``, ``size_category``, the
            cached tag columns and the ``count_by`` column. It is not modified.
        count_by: Metric column to aggregate; coerced to numeric (NaN -> 0).
        top_k: Number of organizations to keep (cast to int; None keeps all).
        tag_filter: A key of ``TAG_FILTER_FUNCS``; any other value is ignored.
        pipeline_filter: Exact ``pipeline_tag`` to keep; falsy means no filter.
        size_filter: A key of ``MODEL_SIZE_RANGES``; any other value is ignored.
        skip_orgs: Organization names to exclude.

    Returns:
        DataFrame with columns ``id``, ``organization``, ``count_by`` and
        ``root`` (always "models"), or an empty DataFrame if nothing survives
        the filters.
    """
    # No up-front df.copy(): boolean indexing already yields new frames, and
    # organization labels live in a separate Series, so the caller's frame
    # (possibly shared UI state) is never mutated.
    filtered_df = df
    filter_stats = {"initial": len(filtered_df)}
    start_time = pd.Timestamp.now()

    if tag_filter and tag_filter in TAG_FILTER_FUNCS:
        print(f"Applying tag filter: {tag_filter}")
        tag_column = _TAG_FILTER_COLUMNS.get(tag_filter)
        if tag_column is not None:
            filtered_df = filtered_df[filtered_df[tag_column]]
        filter_stats["after_tag_filter"] = len(filtered_df)
        print(f"Tag filter applied in {_seconds_since(start_time):.3f} seconds")
        start_time = pd.Timestamp.now()

    if pipeline_filter:
        print(f"Applying pipeline filter: {pipeline_filter}")
        filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
        filter_stats["after_pipeline_filter"] = len(filtered_df)
        print(f"Pipeline filter applied in {_seconds_since(start_time):.3f} seconds")
        start_time = pd.Timestamp.now()

    if size_filter and size_filter in MODEL_SIZE_RANGES:
        print(f"Applying size filter: {size_filter}")
        filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
        print(f"Size filter '{size_filter}' applied.")
        print(f"Models after size filter: {len(filtered_df)}")
        filter_stats["after_size_filter"] = len(filtered_df)
        print(f"Size filter applied in {_seconds_since(start_time):.3f} seconds")
        start_time = pd.Timestamp.now()

    organizations = filtered_df["id"].apply(extract_org_from_id)
    if skip_orgs:
        keep = ~organizations.isin(skip_orgs)
        filtered_df = filtered_df[keep]
        organizations = organizations[keep]
        filter_stats["after_skip_orgs"] = len(filtered_df)

    print("Filter statistics:")
    for stage, count in filter_stats.items():
        print(f" {stage}: {count} models")

    if filtered_df.empty:
        print("Warning: No data left after applying filters!")
        return pd.DataFrame()

    # Coerce the metric BEFORE ranking so the organizations chosen are ranked
    # on exactly the values that end up in the treemap.
    treemap_data = pd.DataFrame({
        "id": filtered_df["id"],
        "organization": organizations,
        count_by: pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0),
    })
    org_totals = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False)
    # Slider values can arrive as floats; head() needs an int.
    n_orgs = len(org_totals) if top_k is None else int(top_k)
    top_orgs = org_totals.head(n_orgs).index
    treemap_data = treemap_data[treemap_data["organization"].isin(top_orgs)].copy()
    treemap_data["root"] = "models"

    print(f"Treemap data prepared in {_seconds_since(start_time):.3f} seconds")
    return treemap_data
def create_treemap(treemap_data, count_by, title=None):
    """Create a Plotly treemap from data prepared by make_treemap_data.

    Args:
        treemap_data: DataFrame with ``root``, ``organization``, ``id`` and
            ``count_by`` columns, or an empty DataFrame.
        count_by: Metric column used for box sizes; also named in the hover text.
        title: Figure title; defaults to one derived from ``count_by``.

    Returns:
        A Plotly figure. For empty input, a single placeholder box titled
        "No data matches the selected filters" is returned instead.
    """
    margins = dict(t=50, l=25, r=25, b=25)

    if treemap_data.empty:
        no_data_message = "No data matches the selected filters"
        fig = px.treemap(names=[no_data_message], values=[1])
        fig.update_layout(title=no_data_message, margin=margins)
        return fig

    fig = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_layout(margin=margins)
    # Plotly hover templates break lines with the HTML tag <br>. The previous
    # literal was split across physical source lines, which is a SyntaxError.
    fig.update_traces(
        textinfo="label+value+percent root",
        hovertemplate="%{label}<br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total",
    )
    return fig
# Source dataset: hub-stats models table on the Hugging Face Hub.
_MODELS_PARQUET_URL = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
# Columns the app reads from the parquet file (safetensors carries size info).
_NEEDED_COLUMNS = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
# Cached boolean tag column -> (substring a lower-cased tag must contain,
# substring that disqualifies that tag, or None).
_TAG_COLUMN_RULES = {
    'has_audio': ('audio', None),
    'has_speech': ('speech', None),
    'has_music': ('music', None),
    'has_robot': ('robot', None),
    'has_bio': ('bio', None),
    'has_med': ('medic', None),
    'has_series': ('series', None),
    # Tags mentioning "bigscience" are deliberately not counted as science.
    'has_science': ('science', 'bigscience'),
    'has_video': ('video', None),
    'has_image': ('image', None),
    'has_text': ('text', None),
}


def _default_column_values(col, n_rows):
    """Return placeholder values for a required column that is missing."""
    if col in ('downloads', 'downloadsAllTime', 'likes', 'params'):
        return 0
    if col == 'pipeline_tag':
        return ''
    if col == 'tags':
        return [[] for _ in range(n_rows)]
    if col == 'id':
        # Synthesize IDs from the row position.
        return [f"model_{i}" for i in range(n_rows)]
    return None  # e.g. safetensors: no size information


def _fetch_models_frame():
    """Download the needed columns of the models parquet file via DuckDB.

    Tries an explicit column selection first. If that query fails, it falls
    back to ``SELECT *`` and fills any missing needed column with placeholders.
    """
    try:
        query = f"SELECT {', '.join(_NEEDED_COLUMNS)} FROM read_parquet('{_MODELS_PARQUET_URL}')"
        return duckdb.sql(query).df()
    except Exception as sql_error:  # broad on purpose: any failure triggers the fallback
        print(f"Error with specific column selection: {sql_error}")
        print("Falling back to select * query...")
    raw_df = duckdb.sql(f"SELECT * FROM read_parquet('{_MODELS_PARQUET_URL}')").df()
    df = pd.DataFrame(index=raw_df.index)
    for col in _NEEDED_COLUMNS:
        if col in raw_df.columns:
            df[col] = raw_df[col]
        else:
            df[col] = _default_column_values(col, len(raw_df))
    return df


def _size_category(size_gb):
    """Return the MODEL_SIZE_RANGES label whose [low, high) range holds size_gb, else None."""
    for label, (low, high) in MODEL_SIZE_RANGES.items():
        # An infinite upper bound means the bucket is open-ended.
        if low <= size_gb and (size_gb < high or high == float('inf')):
            return label
    return None


def _add_size_columns(df):
    """Add ``params`` and cached ``size_category`` columns, then drop ``safetensors``."""
    if 'safetensors' not in df.columns:
        df['params'] = 0
        df['size_category'] = None
        return df
    df['params'] = df['safetensors'].apply(extract_model_size)
    df['size_category'] = df['params'].apply(_size_category)
    # Vectorized count instead of a per-row iterrows() loop.
    counts = df['size_category'].value_counts()
    print("Model size distribution:")
    for label in MODEL_SIZE_RANGES:
        print(f" {label}: {int(counts.get(label, 0))} models")
    return df.drop(columns=['safetensors'])


def _normalize_tags(tags_value):
    """Coerce one raw ``tags`` cell into a list of strings.

    Accepts numpy arrays, lists/tuples, JSON-encoded lists, comma-separated
    strings and missing values (None/NaN -> []). Any other scalar becomes a
    single-element list holding its string form.
    """
    # Unwrap list-likes BEFORE any pd.isna() call. On an array or list,
    # pd.isna returns an element-wise array whose truth value is ambiguous.
    # That previously raised, and every multi-tag model's tags became [].
    if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
        tags_value = tags_value.tolist()
    if isinstance(tags_value, (list, tuple)):
        return [str(tag) for tag in tags_value]
    if isinstance(tags_value, str):
        try:
            parsed = json.loads(tags_value)
        except ValueError:
            # Not JSON: treat as a comma-separated list.
            return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
        if isinstance(parsed, list):
            return [str(tag) for tag in parsed]
        return [tags_value]
    if tags_value is None:
        return []
    try:
        is_missing = bool(pd.isna(tags_value))
    except (TypeError, ValueError):  # some other array-like: not a single missing value
        is_missing = False
    return [] if is_missing else [str(tags_value)]


def _tags_contain(tags, needle, exclude=None):
    """True if some lower-cased tag contains ``needle`` and not ``exclude``."""
    if not isinstance(tags, list) or not tags:
        return False
    for tag in tags:
        lowered = str(tag).lower()
        if needle in lowered and (exclude is None or exclude not in lowered):
            return True
    return False


def _add_tag_columns(df):
    """Normalize ``tags`` and add the cached boolean tag-category columns."""
    flag_columns = [*_TAG_COLUMN_RULES, 'is_audio_speech', 'is_biomed']
    if 'tags' not in df.columns:
        df['tags'] = [[] for _ in range(len(df))]
        for col in flag_columns:
            df[col] = False
        return df
    df['tags'] = df['tags'].apply(_normalize_tags)
    print("Pre-calculating cached tag categories...")
    print("Creating cached tag columns...")
    for col, (needle, exclude) in _TAG_COLUMN_RULES.items():
        df[col] = df['tags'].apply(_tags_contain, args=(needle, exclude))
    # Combined categories: audio/speech also matches on the pipeline tag.
    pipeline = df['pipeline_tag']
    df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
                             pipeline.str.contains('audio', case=False, na=False) |
                             pipeline.str.contains('speech', case=False, na=False))
    df['is_biomed'] = df['has_bio'] | df['has_med']
    print("Cached tag columns created successfully!")
    return df


def load_models_data():
    """Load the Hub models table and precompute the columns the filters use.

    Fetches the hub-stats models parquet via DuckDB. It then adds
    ``params``/``size_category`` (derived from safetensors info), normalizes
    ``tags`` into lists of strings, and adds the cached boolean tag columns.

    Returns:
        ``(df, True)`` on success, or ``(pd.DataFrame(), False)`` if any step
        fails. The error is printed, not raised.
    """
    try:
        print("Fetching data from Hugging Face models.parquet...")
        df = _fetch_models_frame()
        print(f"Data fetched successfully. Shape: {df.shape}")

        df = _add_size_columns(df)
        df = _add_tag_columns(df)

        df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
        if 'pipeline_tag' in df.columns:
            df['pipeline_tag'] = df['pipeline_tag'].fillna('')
        else:
            df['pipeline_tag'] = ''

        # Guarantee every column the UI relies on exists.
        for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
            if col not in df.columns:
                df[col] = _default_column_values(col, len(df))

        print(f"Successfully processed {len(df)} models with cached tag and size information")
        return df, True
    except Exception as e:
        # UI boundary: report failure to the caller instead of crashing the app.
        print(f"Error loading data: {e}")
        return pd.DataFrame(), False
# Human-readable label for each metric column the treemap can be sized by.
# Shared by the Metric dropdown, the figure title and the statistics panel.
METRIC_LABELS = {
    "downloads": "Downloads (last 30 days)",
    "downloadsAllTime": "Downloads (All Time)",
    "likes": "Likes",
}


def _share_pct(value, total):
    """Return ``value`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return (value / total) * 100 if total else 0.0


# Create Gradio interface
with gr.Blocks() as demo:
    models_data = gr.State()
    loading_complete = gr.State(False)  # Flag to indicate data load completion

    with gr.Row():
        gr.Markdown(
            "\n"
            "# HuggingFace Models TreeMap Visualization\n"
            "This app shows how different organizations contribute to the HuggingFace ecosystem with their models.\n"
            "Use the filters to explore models by different metrics, tags, pipelines, and model sizes.\n"
            "The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.\n"
        )

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[(label, key) for key, label in METRIC_LABELS.items()],
                value="downloads",
                info="Select the metric to determine box sizes",
            )
            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models",
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category",
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline",
            )
            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (using params column)",
            )
            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include",
            )
            skip_orgs_textbox = gr.Textbox(
                label="Organizations to Skip (comma-separated)",
                placeholder="e.g., OpenAI, Google",
                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski",
            )
            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
            refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")

        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Loading data from Hugging Face...*")
            data_info = gr.Markdown("")

    def enable_plot_button(loaded):
        """Make the Generate button clickable only once data has loaded."""
        return gr.update(interactive=loaded)

    loading_complete.change(
        fn=enable_plot_button,
        inputs=[loading_complete],
        outputs=[generate_plot_button],
    )

    def update_filter_visibility(filter_choice):
        """Show only the dropdown matching the chosen filter type."""
        return (
            gr.update(visible=filter_choice == "Tag Filter"),
            gr.update(visible=filter_choice == "Pipeline Filter"),
        )

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown],
    )

    def load_and_provide_info():
        """Load the dataset; return (data, loaded flag, info Markdown, status Markdown)."""
        df, success = load_models_data()
        if not success:
            return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
        info_text = (
            "\n"
            "### Data Information\n"
            f"- **Total models loaded**: {len(df):,}\n"
            f"- **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
            "- **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)\n"
        )
        return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"

    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
        """Build the treemap figure and the statistics Markdown for the current controls.

        Returns:
            ``(figure, markdown)``, or ``(None, error message)`` when no data
            is loaded.
        """
        if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
            return None, "Error: Data is still loading. Please wait a moment and try again."

        # Only the dropdown that matches the selected filter type is honoured.
        selected_tag_filter = tag_filter if filter_choice == "Tag Filter" else None
        selected_pipeline_filter = pipeline_filter if filter_choice == "Pipeline Filter" else None
        selected_size_filter = size_filter if size_filter != "None" else None
        skip_orgs = [org.strip() for org in (skip_orgs_text or "").split(',') if org.strip()]

        treemap_data = make_treemap_data(
            df=data_df,
            count_by=count_by,
            top_k=int(top_k),  # slider values may arrive as floats
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_filter=selected_size_filter,
            skip_orgs=skip_orgs,
        )

        metric_label = METRIC_LABELS.get(count_by, count_by)
        fig = create_treemap(
            treemap_data=treemap_data,
            count_by=count_by,
            title=f"HuggingFace Models - {metric_label} by Organization",
        )

        if treemap_data.empty:
            stats_md = "No data matches the selected filters."
        else:
            total_models = len(treemap_data)
            total_value = treemap_data[count_by].sum()
            top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
            top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)

            stats_lines = [
                "",
                "## Statistics",
                f"- **Total models shown**: {total_models:,}",
                f"- **Total {metric_label}**: {int(total_value):,}",
                f"## Top Organizations by {metric_label}",
                f"| Organization | {metric_label} | % of Total |",
                "|--------------|-------------:|----------:|",
            ]
            for org, value in top_5_orgs.items():
                stats_lines.append(f"| {org} | {int(value):,} | {_share_pct(value, total_value):.2f}% |")
            stats_lines += [
                "",
                f"## Top Models by {metric_label}",
                f"| Model | {metric_label} | % of Total |",
                "|-------|-------------:|----------:|",
            ]
            for model_id, value in zip(top_5_models["id"], top_5_models[count_by]):
                stats_lines.append(f"| {model_id} | {int(value):,} | {_share_pct(value, total_value):.2f}% |")
            stats_md = "\n".join(stats_lines) + "\n"

        if skip_orgs:
            stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
        return fig, stats_md

    # NOTE(review): this handler runs load_models_data on every page load, so
    # each new browser session re-downloads and re-processes the full dataset.
    demo.load(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output],
    )

    refresh_data_button.click(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output],
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            top_k_slider,
            skip_orgs_textbox,
            models_data,
        ],
        outputs=[plot_output, stats_output],
    )

if __name__ == "__main__":
    demo.launch()