|
import json |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
import os |
|
import numpy as np |
|
import io |
|
|
|
|
|
# Pipeline tags offered in the "Select Pipeline Tag" dropdown.
# Order is preserved as-is because it is the order shown to the user.
PIPELINE_TAGS = [
    "text-generation",
    "text-to-image",
    "text-classification",
    "text2text-generation",
    "audio-to-audio",
    "feature-extraction",
    "image-classification",
    "translation",
    "reinforcement-learning",
    "fill-mask",
    "text-to-speech",
    "automatic-speech-recognition",
    "image-text-to-text",
    "token-classification",
    "sentence-similarity",
    "question-answering",
    "image-feature-extraction",
    "summarization",
    "zero-shot-image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "audio-classification",
    "visual-question-answering",
    "text-to-video",
    "zero-shot-classification",
    "depth-estimation",
    "text-ranking",
    "image-to-video",
    "multiple-choice",
    "unconditional-image-generation",
    "video-classification",
    "text-to-audio",
    "time-series-forecasting",
    "any-to-any",
    "video-text-to-text",
    "table-question-answering",
]
|
|
|
|
|
# Size buckets for the "Model Size Filter" dropdown, keyed by display label.
# Each value is a (lower, upper) pair; is_in_size_range treats the lower bound
# as inclusive and the upper bound as exclusive.
MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float("inf")),
}
|
|
|
|
|
def is_audio_speech(row):
    """Return True when a model row looks audio- or speech-related.

    A row matches when its ``pipeline_tag`` or any entry of its ``tags``
    contains "audio" or "speech" (case-insensitive).

    Args:
        row: Mapping-like object exposing ``.get`` (a dict or a pandas
            Series), with optional ``"tags"`` and ``"pipeline_tag"`` entries.

    Returns:
        bool: True if either keyword is found, False otherwise.
    """
    keywords = ("audio", "speech")

    def mentions_keyword(value):
        # Missing CSV cells arrive as float NaN, which is truthy but has no
        # ``.lower()``; any non-string value is treated as "no match".
        if not isinstance(value, str):
            return False
        lowered = value.lower()
        return any(keyword in lowered for keyword in keywords)

    if mentions_keyword(row.get("pipeline_tag", "")):
        return True

    tags = row.get("tags", [])
    if tags is None or isinstance(tags, float):
        # A present-but-empty tags cell (None / NaN) means "no tags".
        return False
    return any(mentions_keyword(tag) for tag in tags)
|
|
|
def is_music(row):
    """Return True when any tag of the row contains "music" (case-insensitive)."""
    for tag in row.get("tags", []):
        if "music" in tag.lower():
            return True
    return False
|
|
|
def is_robotics(row):
    """Return True when any tag of the row contains "robot" (case-insensitive)."""
    row_tags = row.get("tags", [])
    matches = ("robot" in label.lower() for label in row_tags)
    return any(matches)
|
|
|
def is_biomed(row):
    """Return True when any tag of the row contains "bio" or "medic" (case-insensitive)."""
    row_tags = row.get("tags", [])
    # One scan per keyword, "bio" first, mirroring a chained ``any(...) or any(...)``.
    for needle in ("bio", "medic"):
        if any(needle in label.lower() for label in row_tags):
            return True
    return False
|
|
|
def is_timeseries(row):
    """Return True when any tag of the row contains "series" (case-insensitive)."""
    hit = next(
        (True for label in row.get("tags", []) if "series" in label.lower()),
        False,
    )
    return hit
|
|
|
def is_science(row):
    """Return True when the row carries a science-related tag.

    A tag matches when it contains "science" (case-insensitive) but does not
    contain "bigscience". The exclusion is applied to the lower-cased tag as
    well, so spellings such as "BigScience" are excluded consistently with
    the case-insensitive "science" match.

    Args:
        row: Mapping-like object exposing ``.get`` with an optional ``"tags"``
            entry holding an iterable of strings.

    Returns:
        bool: True if at least one tag qualifies, False otherwise.
    """
    for tag in row.get("tags", []):
        lowered = tag.lower()
        if "science" in lowered and "bigscience" not in lowered:
            return True
    return False
|
|
|
def is_video(row):
    """Return True when any tag of the row contains "video" (case-insensitive)."""
    found = False
    for label in row.get("tags", []):
        if "video" in label.lower():
            found = True
            break
    return found
|
|
|
def is_image(row):
    """Return True when any tag of the row contains "image" (case-insensitive)."""
    def mentions_image(label):
        return "image" in label.lower()

    return any(mentions_image(label) for label in row.get("tags", []))
|
|
|
def is_text(row):
    """Return True when any tag of the row contains "text" (case-insensitive)."""
    row_tags = row.get("tags", [])
    for label in row_tags:
        if "text" not in label.lower():
            continue
        return True
    return False
|
|
|
|
|
def is_in_size_range(row, size_range):
    """Tell whether a model row falls inside the named size bucket.

    Args:
        row: Mapping-like object (dict or pandas Series) that may carry a
            ``"params"`` entry.
        size_range: A key of ``MODEL_SIZE_RANGES``, or None to disable the
            filter.

    Returns:
        bool: True when ``size_range`` is None, or when ``params`` divided by
        1024**3 lies in ``[lower, upper)`` for the bucket. False when
        ``params`` is absent, NaN, or not convertible to float.
    """
    if size_range is None:
        return True

    lower, upper = MODEL_SIZE_RANGES[size_range]

    if "params" not in row or pd.isna(row["params"]):
        return False

    try:
        # NOTE(review): treats ``params`` as a byte-like quantity scaled to
        # GiB; confirm the unit of the column against the data source.
        gigabytes = float(row["params"]) / (1024 ** 3)
    except (ValueError, TypeError):
        return False

    return lower <= gigabytes < upper
|
|
|
# Display label -> row predicate for the "Select Tag" dropdown.
# Key order is the order in which the choices are shown.
TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}
|
|
|
def extract_org_from_id(model_id):
    """Return the organization prefix of a model ID.

    The organization is everything before the first "/". IDs without a "/"
    are reported as "unaffiliated".
    """
    owner, slash, _ = model_id.partition("/")
    if not slash:
        return "unaffiliated"
    return owner
|
|
|
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
    """Filter the models table and shape it for a treemap.

    Args:
        df: DataFrame with at least ``id``, ``pipeline_tag`` (when a pipeline
            filter is used), ``tags`` / ``params`` (when the matching filters
            are used) and the ``count_by`` column. It is not modified.
        count_by: Name of the metric column that sizes the boxes.
        top_k: Number of organizations, ranked by summed metric, to keep.
        tag_filter: Key of ``TAG_FILTER_FUNCS``; ignored when falsy or unknown.
        pipeline_filter: Exact ``pipeline_tag`` value to keep; ignored when falsy.
        size_filter: Key of ``MODEL_SIZE_RANGES``; ignored when falsy or unknown.
        skip_orgs: Iterable of organization names to drop before ranking.

    Returns:
        DataFrame with columns ``id``, ``organization``, ``count_by`` (numeric,
        missing values as 0) and a constant ``root`` column set to "models".

    Raises:
        KeyError: If a required column such as ``count_by`` or ``id`` is missing.
    """
    filtered_df = df

    # Row-wise ``apply`` on an empty frame may not yield a boolean mask, so the
    # predicate filters are skipped once nothing is left to filter.
    if tag_filter and tag_filter in TAG_FILTER_FUNCS and not filtered_df.empty:
        tag_mask = filtered_df.apply(TAG_FILTER_FUNCS[tag_filter], axis=1)
        filtered_df = filtered_df[tag_mask]

    if pipeline_filter:
        filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]

    if size_filter and size_filter in MODEL_SIZE_RANGES and not filtered_df.empty:
        size_mask = filtered_df.apply(
            lambda row: is_in_size_range(row, size_filter), axis=1
        )
        filtered_df = filtered_df[size_mask]

    # Work on an explicit copy: the caller's frame stays untouched and the
    # column assignments below do not write into a slice of another frame.
    filtered_df = filtered_df.copy()
    filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)

    # Clean the metric BEFORE ranking so that the organization totals and the
    # returned values are computed from the same numbers.
    filtered_df[count_by] = pd.to_numeric(
        filtered_df[count_by], errors="coerce"
    ).fillna(0)

    if skip_orgs:
        filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]

    org_totals = (
        filtered_df.groupby("organization")[count_by]
        .sum()
        .reset_index()
        .sort_values(by=count_by, ascending=False)
    )
    # UI sliders can deliver floats; ``head`` needs an integer count.
    top_orgs = org_totals.head(int(top_k))["organization"].tolist()

    filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]

    treemap_data = filtered_df[["id", "organization", count_by]].copy()
    # Single shared root so the treemap path has one top-level node.
    treemap_data["root"] = "models"

    return treemap_data
|
|
|
def create_treemap(treemap_data, count_by, title=None):
    """Build a Plotly treemap figure from prepared treemap rows.

    Args:
        treemap_data: DataFrame with ``root``, ``organization``, ``id`` and
            the ``count_by`` column.
        count_by: Name of the column used for box sizes; also shown in the
            hover text and in the default title.
        title: Optional figure title; a default is derived from ``count_by``
            when it is falsy.

    Returns:
        The Plotly figure. For an empty frame, a single-box placeholder
        figure carrying a "no data" message is returned instead.
    """
    chart_margin = dict(t=50, l=25, r=25, b=25)

    if treemap_data.empty:
        placeholder = px.treemap(
            names=["No data matches the selected filters"],
            values=[1]
        )
        placeholder.update_layout(
            title="No data matches the selected filters",
            margin=chart_margin
        )
        return placeholder

    if title:
        chart_title = title
    else:
        chart_title = f"HuggingFace Models - {count_by.capitalize()} by Organization"

    figure = px.treemap(
        treemap_data,
        path=["root", "organization", "id"],
        values=count_by,
        title=chart_title,
        color_discrete_sequence=px.colors.qualitative.Plotly
    )
    figure.update_layout(margin=chart_margin)

    # Plotly placeholders (%{...}) are kept out of f-strings on purpose.
    hover_text = "".join([
        "<b>%{label}</b><br>%{value:,} ",
        count_by,
        "<br>%{percentRoot:.2%} of total<extra></extra>",
    ])
    figure.update_traces(
        textinfo="label+value+percent root",
        hovertemplate=hover_text
    )

    return figure
|
|
|
def load_models_csv():
    """Read ``models.csv`` from the working directory and parse its tags.

    The ``tags`` column is converted from its bracketed, quoted string form
    into a list of tag strings; missing cells become an empty list.

    NOTE(review): a function with this same name is defined again inside the
    ``gr.Blocks`` context further down the file and rebinds the name.

    Returns:
        The loaded DataFrame with ``tags`` as lists.
    """
    def process_tags(tags_str):
        if pd.isna(tags_str):
            return []
        # Drop the surrounding brackets and the quotes, then split on
        # whitespace; ``split()`` already discards empty pieces.
        unwrapped = tags_str.strip("[]").replace("'", "")
        return unwrapped.split()

    frame = pd.read_csv('models.csv')
    frame['tags'] = frame['tags'].apply(process_tags)
    return frame
|
|
|
|
|
with gr.Blocks() as demo:
    # State holders: the parsed models table, and a flag that is set once the
    # table has been loaded (it drives the "Generate Plot" button).
    models_data = gr.State()
    loading_complete = gr.State(False)

    with gr.Row():
        gr.Markdown("""
        # HuggingFace Models TreeMap Visualization

        This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
        Use the filters to explore models by different metrics, tags, pipelines, and model sizes.

        The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
        """)

    with gr.Row():
        # Left column: all controls.
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[
                    ("Downloads (last 30 days)", "downloads"),
                    ("Downloads (All Time)", "downloadsAllTime"),
                    ("Likes", "likes")
                ],
                value="downloads",
                info="Select the metric to determine box sizes"
            )

            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models"
            )

            # Hidden until the matching "Filter Type" is selected.
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category"
            )

            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline"
            )

            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (using params column)"
            )

            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include"
            )

            skip_orgs_textbox = gr.Textbox(
                label="Organizations to Skip (comma-separated)",
                placeholder="e.g., OpenAI, Google",
                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
            )

            # Starts disabled; enabled once ``loading_complete`` changes to True.
            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)

        # Right column: outputs.
        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Generate a plot to see statistics*")

    def load_models_csv():
        """Load ``models.csv`` for the UI.

        Rebinds the module-level ``load_models_csv`` name (a ``with`` block
        does not open a new scope); unlike the earlier definition it also
        returns the "loaded" flag.

        Returns:
            tuple: ``(DataFrame with ``tags`` parsed into lists, True)``,
            matching the ``[models_data, loading_complete]`` outputs.
        """
        df = pd.read_csv('models.csv')

        def process_tags(tags_str):
            if pd.isna(tags_str):
                return []
            tags_str = tags_str.strip("[]").replace("'", "")
            # ``split()`` with no argument never yields empty pieces.
            return tags_str.split()

        df['tags'] = df['tags'].apply(process_tags)
        return df, True

    def enable_plot_button(loaded):
        """Make the plot button interactive according to the loaded flag."""
        return gr.update(interactive=loaded)

    loading_complete.change(
        fn=enable_plot_button,
        inputs=[loading_complete],
        outputs=[generate_plot_button]
    )

    def update_filter_visibility(filter_choice):
        """Show only the dropdown that matches the chosen filter type.

        Returns:
            tuple: updates for ``(tag_filter_dropdown, pipeline_filter_dropdown)``.
        """
        return (
            gr.update(visible=filter_choice == "Tag Filter"),
            gr.update(visible=filter_choice == "Pipeline Filter"),
        )

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )

    def _parse_skip_orgs(skip_orgs_text):
        """Split the comma-separated textbox value into trimmed, non-empty names."""
        if not skip_orgs_text or not skip_orgs_text.strip():
            return []
        return [org.strip() for org in skip_orgs_text.split(',') if org.strip()]

    def _build_stats_markdown(treemap_data, count_by, metric_label):
        """Render the summary statistics and top-5 tables as Markdown."""
        if treemap_data.empty:
            return "No data matches the selected filters."

        total_value = treemap_data[count_by].sum()

        def share_of_total(value):
            # With an all-zero metric the division would render as "nan%".
            if not total_value:
                return 0.0
            return (value / total_value) * 100

        top_5_orgs = (
            treemap_data.groupby("organization")[count_by]
            .sum()
            .sort_values(ascending=False)
            .head(5)
        )
        top_5_models = (
            treemap_data[["id", count_by]]
            .sort_values(by=count_by, ascending=False)
            .head(5)
        )

        lines = [
            "",
            "## Statistics",
            f"- **Total models shown**: {len(treemap_data):,}",
            f"- **Total {metric_label}**: {int(total_value):,}",
            "",
            f"## Top Organizations by {metric_label}",
            "",
            f"| Organization | {metric_label} | % of Total |",
            "|--------------|-------------:|----------:|",
        ]
        lines.extend(
            f"| {org} | {int(value):,} | {share_of_total(value):.2f}% |"
            for org, value in top_5_orgs.items()
        )
        lines.extend([
            "",
            f"## Top Models by {metric_label}",
            "",
            f"| Model | {metric_label} | % of Total |",
            "|-------|-------------:|----------:|",
        ])
        lines.extend(
            f"| {model_id} | {int(value):,} | {share_of_total(value):.2f}% |"
            for model_id, value in zip(top_5_models["id"], top_5_models[count_by])
        )
        # Trailing newline keeps a blank line before any appended note.
        return "\n".join(lines) + "\n"

    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
        """Build the treemap and the statistics panel from the current controls.

        Args:
            count_by: Metric column name selected in the "Metric" dropdown.
            filter_choice: "None", "Tag Filter" or "Pipeline Filter".
            tag_filter: Selected tag label (used only with "Tag Filter").
            pipeline_filter: Selected pipeline tag (used only with "Pipeline Filter").
            size_filter: Size bucket label, or the string "None".
            top_k: Number of top organizations to keep.
            skip_orgs_text: Comma-separated organization names to exclude.
            data_df: The models DataFrame held in ``models_data``.

        Returns:
            tuple: ``(figure or None, markdown text)`` for
            ``[plot_output, stats_output]``.
        """
        if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
            return None, "Error: Data is still loading. Please wait a moment and try again."

        # Only the filter matching the selected filter type is applied.
        selected_tag_filter = tag_filter if filter_choice == "Tag Filter" else None
        selected_pipeline_filter = pipeline_filter if filter_choice == "Pipeline Filter" else None
        selected_size_filter = size_filter if size_filter != "None" else None

        skip_orgs = _parse_skip_orgs(skip_orgs_text)

        treemap_data = make_treemap_data(
            df=data_df,
            count_by=count_by,
            # The slider value may arrive as a float; row counts must be ints.
            top_k=int(top_k),
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_filter=selected_size_filter,
            skip_orgs=skip_orgs
        )

        title_labels = {
            "downloads": "Downloads (last 30 days)",
            "downloadsAllTime": "Downloads (All Time)",
            "likes": "Likes"
        }
        # Human-readable metric name, reused for the title and the stats
        # headings (``"downloadsAllTime".capitalize()`` reads "Downloadsalltime").
        metric_label = title_labels.get(count_by, count_by)
        title_text = f"HuggingFace Models - {metric_label} by Organization"

        fig = create_treemap(
            treemap_data=treemap_data,
            count_by=count_by,
            title=title_text
        )

        stats_md = _build_stats_markdown(treemap_data, count_by, metric_label)

        if skip_orgs:
            stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"

        return fig, stats_md

    # Load the CSV when the page opens; setting ``loading_complete`` enables the button.
    demo.load(
        fn=load_models_csv,
        inputs=[],
        outputs=[models_data, loading_complete]
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            top_k_slider,
            skip_orgs_textbox,
            models_data
        ],
        outputs=[plot_output, stats_output]
    )


if __name__ == "__main__":
    demo.launch()