model-rank / app.py
unmodeled-tyler's picture
Updated app to include filtering by parameter size
8106b9b verified
import gradio as gr
from huggingface_hub import HfApi
import pandas as pd
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
import requests
# Initialize HuggingFace API
# Shared client; fetch_models() calls api.list_models() on it.
api = HfApi()
# Maximum age of a cache entry before _get_cached() treats it as a miss.
MODEL_CACHE_DURATION = timedelta(hours=1)
# In-memory cache keyed by _cache_key(sort, last, limit); each value is a dict
# with 'models' (the fetched list) and 'timestamp' (datetime.now() at store time).
model_cache = {}
def _cache_key(sort: str, last: str | None, limit: int) -> tuple[str, str, int]:
return (sort, last or "all", limit)
def _store_cache(sort: str, last: str | None, limit: int, models):
    """Record ``models`` in the module cache under the key for this query.

    The entry is stamped with the current local time so that _get_cached()
    can decide later whether it is still fresh.
    """
    key = _cache_key(sort, last, limit)
    entry = {'models': models, 'timestamp': datetime.now()}
    model_cache[key] = entry
def _get_cached(sort: str, last: str | None, limit: int):
    """Return the cached model list for this query, or None on a miss.

    A miss is either "no entry stored" or "entry older than
    MODEL_CACHE_DURATION"; stale entries are left in place, not evicted.
    """
    entry = model_cache.get(_cache_key(sort, last, limit))
    if entry:
        age = datetime.now() - entry['timestamp']
        if age <= MODEL_CACHE_DURATION:
            return entry['models']
    return None
def _convert_model_payload(payload: dict):
"""Convert raw REST payload to a SimpleNamespace matching ModelInfo attributes."""
data = payload.copy()
model_id = data.get('modelId') or data.get('id')
data['modelId'] = model_id
data.setdefault('tags', data.get('tags') or [])
data.setdefault('likes', data.get('likes', 0) or 0)
data.setdefault('downloads', data.get('downloads', 0) or 0)
data.setdefault('gated', data.get('gated', False))
data.setdefault('private', data.get('private', False))
return SimpleNamespace(**data)
def fetch_models(limit=500, sort="downloads", last: str | None = None):
    """Fetch models from HuggingFace Hub or REST API with caching.

    Args:
        limit: Maximum number of models to request.
        sort: Field the remote listing is sorted by (descending).
        last: When None the listing goes through ``api.list_models``; otherwise
            the value is sent as the ``last`` query parameter of a direct GET
            to the models REST endpoint and each item of the JSON response is
            wrapped by ``_convert_model_payload``.

    Returns:
        The list of models (served from the cache when a fresh entry exists),
        or an empty list if anything goes wrong while fetching. Failures are
        printed, not raised, and are not cached.
    """
    hit = _get_cached(sort, last, limit)
    if hit is not None:
        return hit

    try:
        if last is not None:
            query = {
                "sort": sort,
                "direction": -1,
                "limit": limit,
                "last": last,
                "full": "true"
            }
            reply = requests.get(
                "https://huggingface.co/api/models",
                params=query,
                timeout=30
            )
            reply.raise_for_status()
            models = [_convert_model_payload(item) for item in reply.json()]
        else:
            listing = api.list_models(
                sort=sort,
                direction=-1,
                limit=limit,
                full=True
            )
            models = list(listing)

        _store_cache(sort, last, limit, models)
        return models
    except Exception as e:
        print(f"Error fetching models: {e}")
        return []
def categorize_model(model) -> str:
    """Categorize a model as base, fine-tune, or quant.

    The decision is a cascade of name/tag heuristics, first match wins:

    1. Any quantization marker in the full id (substring) or in the tags
       (exact tag) -> "Quant".
    2. A repo name with more than two dash-separated parts that carries a
       fine-tune marker (substring of the repo name, or exact tag)
       -> "Fine-Tune".
    3. Any base-model marker in the full id or tags -> "Base Model".
    4. A repo owned by one of the listed major organizations whose repo name
       has no fine-tune marker -> "Base Model".
    5. Everything else -> "Fine-Tune".

    Args:
        model: Object exposing ``modelId`` (str) and ``tags`` (list or None).

    Returns:
        One of "Quant", "Fine-Tune" or "Base Model".
    """
    model_id = model.modelId.lower()
    # Fine-tune markers are matched against the repo name only. Matching the
    # full id lets the owner leak in: 'ft' is a substring of 'microsoft', which
    # made every microsoft/* repo look like a fine-tune.
    repo_name = model_id.split('/')[-1]
    tags = [tag.lower() for tag in (model.tags or [])]

    # Check for quant indicators
    quant_patterns = [
        'gguf', 'gptq', 'awq', 'ggml', 'exl2', 'exllamav2',
        'quantized', 'quant', '-q4', '-q5', '-q6', '-q8',
        'k-quant', 'k_m', 'k_s', 'k_l'
    ]
    if any(pattern in model_id or pattern in tags for pattern in quant_patterns):
        return "Quant"

    # Check for fine-tune indicators
    finetune_patterns = [
        'finetune', 'fine-tune', 'ft', 'instruct', 'chat',
        'dpo', 'rlhf', 'sft', 'lora', 'qlora'
    ]
    name_has_finetune_marker = any(pattern in repo_name for pattern in finetune_patterns)
    tagged_as_finetune = any(pattern in tags for pattern in finetune_patterns)

    # Only trust the markers when the name has descriptive suffixes
    # (more than two dash-separated parts).
    if len(repo_name.split('-')) > 2 and (name_has_finetune_marker or tagged_as_finetune):
        return "Fine-Tune"

    # Check for base model indicators
    base_patterns = [
        'base', 'pretrained', 'pre-trained', 'foundation'
    ]
    if any(pattern in model_id or pattern in tags for pattern in base_patterns):
        return "Base Model"

    # Default heuristic: releases from major organizations without fine-tune
    # markers in the repo name are treated as base models.
    major_orgs = ['meta-llama', 'mistralai', 'google', 'microsoft', 'openai', 'facebook', 'tiiuae']
    org = model_id.split('/')[0]
    if org in major_orgs and not name_has_finetune_marker:
        return "Base Model"

    # Default to fine-tune for most other cases
    return "Fine-Tune"
def extract_license(tags) -> str:
    """Return the upper-cased value of the first ``license:`` tag.

    Yields "Unknown" when ``tags`` is empty/None or carries no such tag.
    """
    prefix = "license:"
    first_hit = next((tag for tag in (tags or ()) if tag.startswith(prefix)), None)
    if first_hit is None:
        return "Unknown"
    # Everything after the first colon is the license identifier.
    return first_hit.partition(":")[2].upper()
def extract_param_size(tags, model_id=None) -> str:
    """Extract parameter size from model tags or infer from model name.

    Args:
        tags: Iterable of tag strings, or None. The first tag starting with
            ``params:`` wins and its value is returned upper-cased.
        model_id: Optional repo id used as a fallback when no such tag exists.
            The first ``<number>B`` / ``<number>M`` token (case-insensitive,
            optional decimal part) is taken as the size.

    Returns:
        A string such as "7B", "1.5B" or "125M", or "Unknown" when nothing
        could be determined.
    """
    import re  # function-scope import: `re` is not imported at module level

    for tag in tags or []:
        if tag.startswith("params:"):
            return tag.split(":", 1)[1].upper()

    if model_id:
        # The trailing lookahead requires the unit letter to END the token, so
        # "gpt2-medium" (2 + "m"edium) and "m2m100" are not read as sizes while
        # "llama-2-7b-hf", "opt-125m" and "Qwen2.5-0.5B" still are.
        match = re.search(
            r'(\d+(?:\.\d+)?[BM])(?![A-Za-z0-9])',
            model_id,
            re.IGNORECASE,
        )
        if match:
            return match.group(1).upper()

    return "Unknown"
# Lower-case keyword -> display name of the model family. Iteration order is
# the match priority used by detect_family().
FAMILY_KEYWORDS = {
    "llama": "LLaMA",
    "mistral": "Mistral",
    "mixtral": "Mixtral",
    "phi": "Phi",
    "gemma": "Gemma",
    "qwen": "Qwen",
    "falcon": "Falcon",
    "yi": "Yi",
    "deepseek": "DeepSeek",
    "openelm": "OpenELM",
    "gpt-neox": "GPT-NeoX",
    "opt": "OPT",
    "command": "Command",
}


def detect_family(model_id: str, tags) -> str:
    """Heuristic to map model to a known family.

    Each keyword of FAMILY_KEYWORDS is looked up, in order, in the lower-cased
    model id and in the lower-cased, space-joined tags. Keywords longer than
    three characters match as plain substrings (so "tinyllama" is LLaMA).
    Keywords of three characters or fewer ("phi", "yi", "opt") must not be
    surrounded by letters, otherwise ordinary words such as "dolphin" or the
    "optimum" tag would be mistaken for a family name.

    Args:
        model_id: Repo id, typically ``owner/name``.
        tags: Iterable of tag strings, or None.

    Returns:
        The family display name, or - when no keyword matches - the
        title-cased first dash-separated part of the repo name.
    """
    import re  # function-scope import: `re` is not imported at module level

    lowered = model_id.lower()
    tag_join = " ".join(tags or []).lower()
    haystacks = (lowered, tag_join)

    for keyword, family in FAMILY_KEYWORDS.items():
        if len(keyword) <= 3:
            # Digits and punctuation may touch the keyword ("phi3", "yi-34b"),
            # letters may not.
            pattern = rf"(?<![a-z]){re.escape(keyword)}(?![a-z])"
            matched = any(re.search(pattern, text) for text in haystacks)
        else:
            matched = any(keyword in text for text in haystacks)
        if matched:
            return family

    return model_id.split("/")[-1].split("-")[0].title()
def determine_access(model) -> str:
    """Classify a model as "Gated" or "Open".

    A model counts as gated when either its ``gated`` or its ``private``
    attribute is truthy; a missing attribute is treated as False.
    """
    restricted = any(getattr(model, flag, False) for flag in ("gated", "private"))
    return "Gated" if restricted else "Open"
# A hidden gem must have strictly fewer downloads than this.
HIDDEN_GEM_DOWNLOAD_LIMIT = 2_000
HIDDEN_GEM_RATIO_THRESHOLD = 0.1  # likes per download

# Substrings that, when found inside any tag, count as a reproducibility signal.
REPRO_TAG_KEYWORDS = [
    "reproducibility",
    "reproducible",
    "replicate",
    "benchmark",
    "leaderboard",
    "evaluation",
    "arxiv:",
    "paper",
    "paperswithcode"
]


def has_reproducibility_signal(tags) -> bool:
    """Tell whether any tag contains one of REPRO_TAG_KEYWORDS (case-insensitive)."""
    if not tags:
        return False
    lowered_tags = [tag.lower() for tag in tags]
    return any(
        keyword in tag
        for keyword in REPRO_TAG_KEYWORDS
        for tag in lowered_tags
    )


def is_hidden_gem(downloads: int, likes: int, tags) -> bool:
    """Decide whether a model qualifies as a "hidden gem".

    All three conditions must hold: downloads below
    HIDDEN_GEM_DOWNLOAD_LIMIT, a likes-per-download ratio of at least
    HIDDEN_GEM_RATIO_THRESHOLD, and a reproducibility signal in the tags.
    None counts as 0 for both counters; with zero downloads the like count
    itself is used as the ratio.
    """
    download_count = 0 if downloads is None else downloads
    like_count = 0 if likes is None else likes
    ratio = like_count if download_count == 0 else like_count / download_count

    if not (download_count < HIDDEN_GEM_DOWNLOAD_LIMIT):
        return False
    if not (ratio >= HIDDEN_GEM_RATIO_THRESHOLD):
        return False
    return has_reproducibility_signal(tags)
def get_filter_options():
    """Collect the family and license choices offered by the UI.

    Uses the default ``fetch_models()`` listing and returns two alphabetically
    sorted lists ``(families, licenses)`` with duplicates and empty strings
    removed.
    """
    models = fetch_models()
    families = {detect_family(model.modelId, model.tags or []) for model in models}
    licenses = {extract_license(model.tags or []) for model in models}
    # An empty string is not a meaningful dropdown choice.
    families -= {""}
    licenses -= {""}
    return sorted(families), sorted(licenses)
def format_number(num):
    """Format large numbers for display.

    Args:
        num: Anything ``int()`` accepts; values that cannot be converted
            (None, non-numeric strings, infinities, ...) are shown as "0".

    Returns:
        ``"<x.y>M"`` for millions, ``"<x.y>K"`` for thousands, otherwise the
        plain integer as a string.
    """
    try:
        value = int(num)
    except (TypeError, ValueError, OverflowError):
        return "0"

    if value >= 1_000_000:
        return f"{value / 1_000_000:.1f}M"
    if value >= 1_000:
        thousands = round(value / 1_000, 1)
        # 999_950..999_999 would round up to "1000.0K"; roll over to millions.
        if thousands >= 1_000:
            return f"{value / 1_000_000:.1f}M"
        return f"{thousands:.1f}K"
    return str(value)
# Column order shared by the empty and the populated result table, so both
# always have the same shape.
_RESULT_COLUMNS = [
    "Rank", "Model", "Downloads", "Likes", "Category", "Family",
    "License", "Access", "Params", "Created"
]

# A param_size_max at or above this value means "no upper bound".
_PARAM_SIZE_MAX_UNBOUNDED = 100


def _normalize_model_date(raw):
    """Return ``raw`` as a timezone-aware UTC datetime, or None if unusable.

    Accepts a datetime (naive values are assumed to be UTC) or anything whose
    ``str()`` is an ISO-8601 timestamp, with a trailing "Z" accepted as UTC.
    """
    if raw is None:
        return None
    if not isinstance(raw, datetime):
        try:
            iso_str = str(raw)
            if iso_str.endswith('Z'):
                iso_str = iso_str[:-1] + '+00:00'
            raw = datetime.fromisoformat(iso_str)
        except (TypeError, ValueError):
            return None
    if raw.tzinfo is None:
        return raw.replace(tzinfo=timezone.utc)
    return raw.astimezone(timezone.utc)


def _param_size_in_billions(param_size):
    """Convert a size label such as "7B" or "125M" to billions, or None if unparsable."""
    if param_size.endswith('B'):
        divisor = 1
    elif param_size.endswith('M'):
        divisor = 1000
    else:
        return None
    try:
        return float(param_size[:-1]) / divisor
    except ValueError:
        return None


def process_models(category_filter="All", family_filter="All", license_filter="All",
                   access_filter="All", hidden_gems_only=False,
                   sort_by="downloads", max_results=50, timeframe="All Time",
                   active_only=False, activity_window="30d", param_size_min=None, param_size_max=None):
    """Process and filter models based on category, metadata filters, and sort preference.

    Args:
        category_filter: "All" or a category returned by categorize_model().
        family_filter: "All" or a family returned by detect_family().
        license_filter: "All" or a license returned by extract_license().
        access_filter: "All", "Open" or "Gated".
        hidden_gems_only: Keep only models for which is_hidden_gem() is true;
            also switches the remote sort to likes and raises the fetch limit.
        sort_by: "likes" sorts the result by likes, anything else by downloads.
        max_results: Maximum number of rows (None means 50).
        timeframe: One of "Last Day", "Last Week", "Last Month",
            "Last 3 Months", "All Time"; models older than the cutoff, or
            without a usable date, are dropped when a cutoff applies.
        active_only: When true, ``activity_window`` is forwarded to
            fetch_models() as its ``last`` argument and the fetch limit is raised.
        activity_window: Window string used when ``active_only`` is set.
        param_size_min: Lower bound in billions; None or <= 0 disables it.
        param_size_max: Upper bound in billions; None or >= 100 disables it.
            While either bound is active, models of unknown size are dropped.

    Returns:
        A pandas DataFrame with the columns "Rank", "Model", "Downloads",
        "Likes", "Category", "Family", "License", "Access", "Params",
        "Created" - empty (but with the same columns) when nothing was fetched.
    """
    fetch_limit = 1000 if hidden_gems_only or active_only else 500
    last_param = activity_window if active_only else None
    sort_param = "likes" if hidden_gems_only else sort_by
    models = fetch_models(limit=fetch_limit, sort=sort_param, last=last_param)
    if not models:
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # Calculate timeframe cutoff
    now = datetime.now(timezone.utc)
    timeframe_cutoffs = {
        "Last Day": now - timedelta(days=1),
        "Last Week": now - timedelta(weeks=1),
        "Last Month": now - timedelta(days=30),
        "Last 3 Months": now - timedelta(days=90),
        "All Time": None
    }
    cutoff_date = timeframe_cutoffs.get(timeframe)

    wants_min = param_size_min is not None and param_size_min > 0
    wants_max = param_size_max is not None and param_size_max < _PARAM_SIZE_MAX_UNBOUNDED

    # NOTE(review): recent huggingface_hub ModelInfo objects appear to expose
    # snake_case date attributes while raw REST payloads use camelCase; both
    # spellings are tried - confirm against the installed library version.
    date_attrs = ('created_at', 'createdAt', 'last_modified', 'lastModified')

    # Process model data
    model_data = []
    for model in models:
        # Determine model date: creation date preferred, last change otherwise.
        raw_date = next(
            (value for value in (getattr(model, attr, None) for attr in date_attrs) if value),
            None,
        )
        model_date = _normalize_model_date(raw_date)
        if cutoff_date is not None:
            if model_date is None or model_date < cutoff_date:
                continue

        tags = model.tags or []
        license_tag = extract_license(tags)
        family = detect_family(model.modelId, tags)
        access = determine_access(model)
        hidden_gem = is_hidden_gem(getattr(model, 'downloads', 0), getattr(model, 'likes', 0), tags)
        category = categorize_model(model)
        param_size = extract_param_size(tags, model.modelId)

        # Apply filters
        if category_filter != "All" and category != category_filter:
            continue
        if family_filter != "All" and family != family_filter:
            continue
        if license_filter != "All" and license_tag != license_filter:
            continue
        if access_filter != "All" and access != access_filter:
            continue
        if hidden_gems_only and not hidden_gem:
            continue

        # Apply parameter size filter if specified
        if wants_min or wants_max:
            size_b = _param_size_in_billions(param_size)
            if size_b is None:
                # Unknown or unparsable sizes cannot satisfy a size bound.
                continue
            if wants_min and size_b < param_size_min:
                continue
            if wants_max and size_b > param_size_max:
                continue

        downloads = getattr(model, 'downloads', 0) or 0
        likes = getattr(model, 'likes', 0) or 0
        model_id = model.modelId
        author = model_id.split('/')[0] if '/' in model_id else 'N/A'
        name = model_id.split('/')[-1] if '/' in model_id else model_id
        created_str = model_date.strftime("%Y-%m-%d") if model_date else "N/A"
        model_data.append({
            'model_id': model_id,
            'downloads': downloads,
            'likes': likes,
            'category': category,
            'author': author,
            'name': name,
            'created': created_str,
            'license': license_tag,
            'family': family,
            'access': access,
            'hidden_gem': hidden_gem,
            'param_size': param_size,
        })

    # Sort models: by likes only when explicitly requested, downloads otherwise.
    sort_field = 'likes' if sort_by == "likes" else 'downloads'
    model_data.sort(key=lambda x: x[sort_field], reverse=True)

    # Limit results
    model_data = model_data[:int(max_results) if max_results is not None else 50]

    # Create DataFrame for display
    df_data = []
    for idx, model in enumerate(model_data, 1):
        gem_badge = "💎 " if model['hidden_gem'] else ""
        display_label = f"{gem_badge}{model['model_id']}"
        link = f"https://huggingface.co/{model['model_id']}"
        df_data.append({
            "Rank": idx,
            "Model": f"[{display_label}]({link})",
            "Downloads": format_number(model['downloads']),
            "Likes": format_number(model['likes']),
            "Category": model['category'],
            "Family": model['family'],
            "License": model['license'],
            "Access": model['access'],
            "Params": model['param_size'],
            "Created": model['created']
        })
    return pd.DataFrame(df_data, columns=_RESULT_COLUMNS)
def create_ui():
    """Create the Gradio interface.

    Builds a two-column layout: filter controls on the left, the result table
    on the right. Every filter control re-runs process_models() when its value
    changes; the "Refresh Data" button additionally clears the model cache
    first. The table is also populated once when the page loads.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    families, licenses = get_filter_options()
    family_choices = ["All"] + families
    license_choices = ["All"] + licenses
    access_choices = ["All", "Open", "Gated"]

    with gr.Blocks(theme=gr.themes.Soft(), title="Model Rank") as app:
        gr.Markdown("# Model Rank")
        with gr.Row():
            with gr.Column(scale=1):
                category = gr.Radio(
                    choices=["All", "Base Model", "Fine-Tune", "Quant"],
                    value="All",
                    label="Category Filter",
                    info="Filter models by category"
                )
                family = gr.Dropdown(
                    choices=family_choices,
                    value="All",
                    label="Model Family",
                    allow_custom_value=False
                )
                license_filter = gr.Dropdown(
                    choices=license_choices,
                    value="All",
                    label="License",
                    allow_custom_value=False
                )
                param_size_min = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=0.1,
                    label="Min Parameter Size (B)",
                    info="Minimum parameter size in billions (0 = no minimum)"
                )
                param_size_max = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=100,
                    step=0.1,
                    label="Max Parameter Size (B)",
                    info="Maximum parameter size in billions (100 = no maximum)"
                )
                access = gr.Radio(
                    choices=access_choices,
                    value="All",
                    label="Access",
                    info="Filter by whether the model is open or gated"
                )
                hidden_gems = gr.Checkbox(
                    value=False,
                    label="Show Hidden Gems only",
                    info="Models with reproducibility tags, high likes/downloads ratio, and <2K downloads"
                )
                active_only = gr.Checkbox(
                    value=False,
                    label="Only Active Models",
                    info="Restrict to models with downloads in the selected recent window"
                )
                activity_window = gr.Radio(
                    choices=["7d", "14d", "30d", "90d"],
                    value="30d",
                    label="Activity Window",
                    info="Period for measuring recent downloads"
                )
                timeframe = gr.Radio(
                    choices=["Last Day", "Last Week", "Last Month", "Last 3 Months", "All Time"],
                    value="All Time",
                    label="Timeframe",
                    info="Filter by when model was created"
                )
                sort_by = gr.Radio(
                    choices=["downloads", "likes"],
                    value="downloads",
                    label="Sort By",
                    info="Sort models by downloads or likes"
                )
                max_results = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Max Results",
                    info="Number of models to display"
                )
                refresh_btn = gr.Button("Refresh Data", variant="primary")
            with gr.Column(scale=3):
                output = gr.Dataframe(
                    headers=["Rank", "Model", "Downloads", "Likes", "Category", "Family", "License", "Access", "Params", "Created"],
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str", "str", "str"],
                    label="Models",
                    wrap=True,
                    interactive=False,
                    column_widths=["5%", "28%", "8%", "7%", "9%", "9%", "8%", "7%", "7%", "12%"]
                )

        # Event handlers
        def update_table(cat, fam, lic, acc, gems, active, window, time, sort, max_res, param_min, param_max):
            # Argument order follows `inputs` below, which differs from the
            # positional order of process_models().
            return process_models(cat, fam, lic, acc, gems, sort, int(max_res), time, active, window, param_min, param_max)

        def refresh_and_update(*values):
            # Clear cache to force refresh
            model_cache.clear()
            return update_table(*values)

        inputs = [category, family, license_filter, access, hidden_gems, active_only, activity_window, timeframe, sort_by, max_results, param_size_min, param_size_max]

        # Every control in `inputs` refreshes the table when it changes. This
        # includes the two parameter-size sliders, which previously had no
        # listener and therefore only took effect after some other control
        # was touched.
        for control in inputs:
            control.change(fn=update_table, inputs=inputs, outputs=output)
        refresh_btn.click(fn=refresh_and_update, inputs=inputs, outputs=output)

        # Load initial data
        app.load(fn=update_table, inputs=inputs, outputs=output)
    return app
# Script entry point: build the interface and start the Gradio server.
if __name__ == "__main__":
    app = create_ui()
    app.launch()