Spaces:
Running
Running
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| import pandas as pd | |
| from datetime import datetime, timedelta, timezone | |
| from types import SimpleNamespace | |
| import requests | |
# Shared Hugging Face Hub client used for model listings.
api = HfApi()

# A cached listing older than this is treated as stale and ignored.
MODEL_CACHE_DURATION = timedelta(hours=1)

# In-memory listing cache. Each value is a dict holding the fetched
# 'models' and the 'timestamp' at which they were stored.
model_cache: dict = {}
| def _cache_key(sort: str, last: str | None, limit: int) -> tuple[str, str, int]: | |
| return (sort, last or "all", limit) | |
def _store_cache(sort: str, last: str | None, limit: int, models):
    """Save ``models`` in the module cache, stamped with the current time."""
    entry = dict(models=models, timestamp=datetime.now())
    model_cache[_cache_key(sort, last, limit)] = entry
def _get_cached(sort: str, last: str | None, limit: int):
    """Return the cached models for this request, or None if absent or stale.

    An entry counts as stale once it is older than MODEL_CACHE_DURATION.
    """
    cached_entry = model_cache.get(_cache_key(sort, last, limit))
    if cached_entry:
        age = datetime.now() - cached_entry['timestamp']
        if age <= MODEL_CACHE_DURATION:
            return cached_entry['models']
    return None
| def _convert_model_payload(payload: dict): | |
| """Convert raw REST payload to a SimpleNamespace matching ModelInfo attributes.""" | |
| data = payload.copy() | |
| model_id = data.get('modelId') or data.get('id') | |
| data['modelId'] = model_id | |
| data.setdefault('tags', data.get('tags') or []) | |
| data.setdefault('likes', data.get('likes', 0) or 0) | |
| data.setdefault('downloads', data.get('downloads', 0) or 0) | |
| data.setdefault('gated', data.get('gated', False)) | |
| data.setdefault('private', data.get('private', False)) | |
| return SimpleNamespace(**data) | |
def fetch_models(limit=500, sort="downloads", last: str | None = None):
    """Return a model listing, served from the module cache when possible.

    With ``last`` unset the listing comes from ``api.list_models``; otherwise
    the public REST endpoint is queried with the ``last`` window and each raw
    payload is converted by ``_convert_model_payload``. Successful results are
    cached. Any failure is printed and an empty list is returned (and not
    cached).
    """
    hit = _get_cached(sort, last, limit)
    if hit is not None:
        return hit

    try:
        if last is not None:
            query = dict(sort=sort, direction=-1, limit=limit, last=last, full="true")
            reply = requests.get(
                "https://huggingface.co/api/models",
                params=query,
                timeout=30,
            )
            reply.raise_for_status()
            models = [_convert_model_payload(item) for item in reply.json()]
        else:
            listing = api.list_models(sort=sort, direction=-1, limit=limit, full=True)
            models = list(listing)
        _store_cache(sort, last, limit, models)
        return models
    except Exception as e:
        print(f"Error fetching models: {e}")
        return []
def categorize_model(model) -> str:
    """Categorize a model as "Quant", "Fine-Tune", or "Base Model".

    Heuristics are applied in this order:

    1. Any quantization marker in the id or tags -> "Quant".
    2. A repo name with more than two dash-separated parts that carries a
       fine-tune marker (in the name or as an exact tag) -> "Fine-Tune".
    3. Any base-model marker in the id or tags -> "Base Model".
    4. A major organization whose repo name has no fine-tune marker
       -> "Base Model".
    5. Everything else -> "Fine-Tune".

    Args:
        model: Object exposing ``modelId`` (str) and ``tags`` (list or None).

    Returns:
        One of "Quant", "Fine-Tune", "Base Model".
    """
    model_id = (model.modelId or "").lower()
    tags = [tag.lower() for tag in (model.tags or [])]
    org = model_id.split('/')[0]
    repo_name = model_id.split('/')[-1]

    # Check for quant indicators
    quant_patterns = (
        'gguf', 'gptq', 'awq', 'ggml', 'exl2', 'exllamav2',
        'quantized', 'quant', '-q4', '-q5', '-q6', '-q8',
        'k-quant', 'k_m', 'k_s', 'k_l',
    )
    if any(pattern in model_id or pattern in tags for pattern in quant_patterns):
        return "Quant"

    # Check for fine-tune indicators
    finetune_patterns = (
        'finetune', 'fine-tune', 'ft', 'instruct', 'chat',
        'dpo', 'rlhf', 'sft', 'lora', 'qlora',
    )
    # Fine-tune markers are matched against the repo name only: matching the
    # org-qualified id made 'ft' hit the org "microsoft", so none of its models
    # could ever be classified as a base model by the major-org rule below.
    name_tokens = set(repo_name.replace('_', '-').replace('.', '-').split('-'))

    def name_has_marker(pattern: str) -> bool:
        # The two-letter 'ft' is far too common as a substring ("soft",
        # "swift"), so it must appear as a whole separator-delimited token.
        if pattern == 'ft':
            return pattern in name_tokens
        return pattern in repo_name

    name_is_finetune = any(name_has_marker(pattern) for pattern in finetune_patterns)
    tag_is_finetune = any(pattern in tags for pattern in finetune_patterns)

    # Names with several descriptive parts are the ones likely to carry a
    # fine-tune suffix.
    if len(repo_name.split('-')) > 2 and (name_is_finetune or tag_is_finetune):
        return "Fine-Tune"

    # Check for base model indicators
    base_patterns = ('base', 'pretrained', 'pre-trained', 'foundation')
    if any(pattern in model_id or pattern in tags for pattern in base_patterns):
        return "Base Model"

    # Default heuristic: releases from major organizations without a fine-tune
    # marker in their name are most likely base models.
    major_orgs = ('meta-llama', 'mistralai', 'google', 'microsoft', 'openai', 'facebook', 'tiiuae')
    if org in major_orgs and not name_is_finetune:
        return "Base Model"

    # Default to fine-tune for most other cases
    return "Fine-Tune"
def extract_license(tags) -> str:
    """Return the upper-cased value of the first ``license:`` tag, else "Unknown"."""
    prefix = "license:"
    found = (
        tag.split(":", 1)[1].upper()
        for tag in (tags or ())
        if tag.startswith(prefix)
    )
    return next(found, "Unknown")
def extract_param_size(tags, model_id=None) -> str:
    """Extract parameter size from model tags or infer it from the model name.

    Args:
        tags: Iterable of tag strings, or None. The first tag starting with
            ``params:`` wins and its value is returned upper-cased.
        model_id: Optional model id searched for a size token such as ``7b``,
            ``1.5B`` or ``125m`` when no ``params:`` tag exists.

    Returns:
        A size string such as "7B" or "125M", or "Unknown" when nothing matches.
    """
    import re

    for tag in tags or ():
        if tag.startswith("params:"):
            return tag.split(":", 1)[1].upper()

    if not model_id:
        return "Unknown"

    # A number followed by B/M, where the unit letter must end the token.
    # The lookahead rejects fragments such as "4bit"/"8bit" or "t5base" that a
    # bare "<digits>[BM]" search would misread as parameter counts; the
    # lookbehind prevents starting in the middle of a decimal number.
    match = re.search(
        r'(?<![\d.])(\d+(?:\.\d+)?)([bm])(?![a-z])',
        model_id,
        re.IGNORECASE,
    )
    if match is None:
        return "Unknown"
    return f"{match.group(1)}{match.group(2).upper()}"
# Lower-case keyword -> display name of the model family it indicates.
# Insertion order matters: lookups take the first keyword that matches.
FAMILY_KEYWORDS = dict([
    ("llama", "LLaMA"),
    ("mistral", "Mistral"),
    ("mixtral", "Mixtral"),
    ("phi", "Phi"),
    ("gemma", "Gemma"),
    ("qwen", "Qwen"),
    ("falcon", "Falcon"),
    ("yi", "Yi"),
    ("deepseek", "DeepSeek"),
    ("openelm", "OpenELM"),
    ("gpt-neox", "GPT-NeoX"),
    ("opt", "OPT"),
    ("command", "Command"),
])
def detect_family(model_id: str, tags) -> str:
    """Map a model to a known family, falling back to its name prefix.

    The first FAMILY_KEYWORDS keyword found in the lower-cased id or in the
    lower-cased, space-joined tags decides the family. Without a match, the
    title-cased text before the first dash of the repo name is returned.
    """
    id_text = model_id.lower()
    tag_text = " ".join(tags or []).lower()
    matched = next(
        (
            label
            for needle, label in FAMILY_KEYWORDS.items()
            if needle in id_text or needle in tag_text
        ),
        None,
    )
    if matched is not None:
        return matched
    repo_name = model_id.rsplit("/", 1)[-1]
    return repo_name.split("-", 1)[0].title()
def determine_access(model) -> str:
    """Return "Gated" when the model is flagged gated or private, else "Open".

    Missing ``gated``/``private`` attributes are treated as False.
    """
    restricted = getattr(model, "gated", False) or getattr(model, "private", False)
    return "Gated" if restricted else "Open"
# Hidden-gem thresholds: a model qualifies only below this download count ...
HIDDEN_GEM_DOWNLOAD_LIMIT = 2000
# ... and at or above this likes-per-download ratio.
HIDDEN_GEM_RATIO_THRESHOLD = 0.1

# Substrings that, when found in any tag, count as a reproducibility signal.
REPRO_TAG_KEYWORDS = [
    "reproducibility", "reproducible", "replicate",
    "benchmark", "leaderboard", "evaluation",
    "arxiv:", "paper", "paperswithcode",
]
def has_reproducibility_signal(tags) -> bool:
    """Return True when any tag contains one of REPRO_TAG_KEYWORDS.

    Matching is a case-insensitive substring test; empty or missing tags
    yield False.
    """
    if not tags:
        return False
    lowered = [tag.lower() for tag in tags]
    return any(
        keyword in tag
        for keyword in REPRO_TAG_KEYWORDS
        for tag in lowered
    )
def is_hidden_gem(downloads: int, likes: int, tags) -> bool:
    """Decide whether a model counts as a "hidden gem".

    All three conditions must hold: downloads below
    HIDDEN_GEM_DOWNLOAD_LIMIT, a likes-per-download ratio of at least
    HIDDEN_GEM_RATIO_THRESHOLD, and a reproducibility signal in the tags.
    None counts as 0, and with zero downloads the raw like count is used as
    the ratio.
    """
    downloads = 0 if downloads is None else downloads
    likes = 0 if likes is None else likes
    ratio = likes / downloads if downloads != 0 else likes

    if not (downloads < HIDDEN_GEM_DOWNLOAD_LIMIT):
        return False
    if not (ratio >= HIDDEN_GEM_RATIO_THRESHOLD):
        return False
    return has_reproducibility_signal(tags)
def get_filter_options():
    """Collect the sorted, distinct families and licenses of the default listing.

    Returns:
        A ``(families, licenses)`` pair of sorted lists with empty strings
        removed, derived from ``fetch_models()`` called with its defaults.
    """
    models = fetch_models()
    families = {detect_family(entry.modelId, entry.tags or []) for entry in models}
    licenses = {extract_license(entry.tags or []) for entry in models}
    for names in (families, licenses):
        names.discard("")
    return sorted(families), sorted(licenses)
def format_number(num):
    """Format a count for display, abbreviating thousands (K) and millions (M).

    Args:
        num: Anything ``int()`` accepts. Values that cannot be converted
            (None, non-numeric strings, infinities, NaN) are shown as "0".

    Returns:
        The plain integer below 1,000, otherwise one decimal place followed
        by "K" or "M".
    """
    try:
        num = int(num)
    except (TypeError, ValueError, OverflowError):
        return "0"

    if num >= 1_000:
        thousands = f"{num / 1_000:.1f}"
        # Values just under a million round up to "1000.0"; show those as
        # "1.0M" instead of the misleading "1000.0K".
        if num < 1_000_000 and thousands != "1000.0":
            return f"{thousands}K"
        return f"{num / 1_000_000:.1f}M"
    return str(num)
# Column layout shared by the populated and the empty result table.
_RESULT_COLUMNS = [
    "Rank", "Model", "Downloads", "Likes", "Category", "Family",
    "License", "Access", "Params", "Created"
]

# A maximum at or above this value (in billions) means "no upper bound".
_PARAM_SIZE_NO_MAX = 100


def _model_timestamp(model):
    """Return the first non-empty creation/modification value on ``model``, or None."""
    # NOTE(review): the snake_case names are what newer huggingface_hub
    # releases are understood to expose; the camelCase names cover raw REST
    # payloads and older clients. Confirm against the installed version.
    for attr in ('created_at', 'createdAt', 'last_modified', 'lastModified'):
        value = getattr(model, attr, None)
        if value:
            return value
    return None


def _to_utc_datetime(value):
    """Coerce a datetime or ISO-8601 string to an aware UTC datetime, or None."""
    if value is None:
        return None
    if not isinstance(value, datetime):
        text = str(value)
        # Normalize a trailing "Z" to an explicit offset so that
        # datetime.fromisoformat() can parse it on every Python version.
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        try:
            value = datetime.fromisoformat(text)
        except ValueError:
            return None
    if value.tzinfo is None:
        # Naive timestamps are taken to be UTC already.
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _param_size_in_billions(param_size):
    """Convert a size string such as "7B" or "125M" to billions, or None."""
    if param_size.endswith('B'):
        divisor = 1
    elif param_size.endswith('M'):
        divisor = 1000
    else:
        return None
    try:
        return float(param_size[:-1]) / divisor
    except ValueError:
        return None


def process_models(category_filter="All", family_filter="All", license_filter="All",
                   access_filter="All", hidden_gems_only=False,
                   sort_by="downloads", max_results=50, timeframe="All Time",
                   active_only=False, activity_window="30d", param_size_min=None, param_size_max=None):
    """Fetch models, apply the filters, and build the ranked display table.

    Args:
        category_filter: "All" or a category returned by ``categorize_model``.
        family_filter: "All" or a family returned by ``detect_family``.
        license_filter: "All" or a license returned by ``extract_license``.
        access_filter: "All", "Open" or "Gated".
        hidden_gems_only: Keep only hidden gems; the listing is then fetched
            sorted by likes with a larger limit.
        sort_by: "likes" sorts by likes; any other value sorts by downloads.
        max_results: Maximum number of rows (None means 50).
        timeframe: Key such as "Last Week"; models older than the cutoff, or
            with no usable date, are dropped. "All Time" or an unknown key
            disables the cutoff.
        active_only: Fetch with ``activity_window`` passed as ``last``.
        activity_window: Window string forwarded when ``active_only`` is set.
        param_size_min: Minimum size in billions; None or 0 disables it.
        param_size_max: Maximum size in billions; None or >= 100 disables it.

    Returns:
        A pandas DataFrame with the columns Rank, Model, Downloads, Likes,
        Category, Family, License, Access, Params and Created. It is empty,
        with the same columns, when no models could be fetched.
    """
    fetch_limit = 1000 if hidden_gems_only or active_only else 500
    last_param = activity_window if active_only else None
    sort_param = "likes" if hidden_gems_only else sort_by
    models = fetch_models(limit=fetch_limit, sort=sort_param, last=last_param)
    if not models:
        # Use the full column set so the empty table has the same shape as a
        # populated one (the "Params" column used to be missing here).
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # Calculate timeframe cutoff
    now = datetime.now(timezone.utc)
    timeframe_cutoffs = {
        "Last Day": now - timedelta(days=1),
        "Last Week": now - timedelta(weeks=1),
        "Last Month": now - timedelta(days=30),
        "Last 3 Months": now - timedelta(days=90),
        "All Time": None
    }
    cutoff_date = timeframe_cutoffs.get(timeframe)

    min_active = param_size_min is not None and param_size_min > 0
    max_active = param_size_max is not None and param_size_max < _PARAM_SIZE_NO_MAX

    # Process model data
    model_data = []
    for model in models:
        model_date = _to_utc_datetime(_model_timestamp(model))
        if cutoff_date is not None and (model_date is None or model_date < cutoff_date):
            continue

        tags = model.tags or []
        model_id = model.modelId
        downloads = getattr(model, 'downloads', 0) or 0
        likes = getattr(model, 'likes', 0) or 0

        license_tag = extract_license(tags)
        family = detect_family(model_id, tags)
        access = determine_access(model)
        hidden_gem = is_hidden_gem(downloads, likes, tags)
        category = categorize_model(model)
        param_size = extract_param_size(tags, model_id)

        # Apply filters
        if category_filter != "All" and category != category_filter:
            continue
        if family_filter != "All" and family != family_filter:
            continue
        if license_filter != "All" and license_tag != license_filter:
            continue
        if access_filter != "All" and access != access_filter:
            continue
        if hidden_gems_only and not hidden_gem:
            continue

        # Apply parameter size bounds; models of unknown or unparsable size
        # are excluded whenever a bound is active.
        if min_active or max_active:
            size_b = _param_size_in_billions(param_size)
            if size_b is None:
                continue
            if min_active and size_b < param_size_min:
                continue
            if max_active and size_b > param_size_max:
                continue

        has_org = '/' in model_id
        model_data.append({
            'model_id': model_id,
            'downloads': downloads,
            'likes': likes,
            'category': category,
            'author': model_id.split('/')[0] if has_org else 'N/A',
            'name': model_id.split('/')[-1] if has_org else model_id,
            'created': model_date.strftime("%Y-%m-%d") if model_date else "N/A",
            'license': license_tag,
            'family': family,
            'access': access,
            'hidden_gem': hidden_gem,
            'param_size': param_size,
        })

    # Sort models: anything other than "likes" ranks by downloads.
    sort_key = 'likes' if sort_by == "likes" else 'downloads'
    model_data.sort(key=lambda row: row[sort_key], reverse=True)

    # Limit results
    row_limit = int(max_results) if max_results is not None else 50
    model_data = model_data[:row_limit]

    # Create DataFrame for display
    df_data = []
    for idx, row in enumerate(model_data, 1):
        gem_badge = "💎 " if row['hidden_gem'] else ""
        display_label = f"{gem_badge}{row['model_id']}"
        link = f"https://huggingface.co/{row['model_id']}"
        df_data.append({
            "Rank": idx,
            "Model": f"[{display_label}]({link})",
            "Downloads": format_number(row['downloads']),
            "Likes": format_number(row['likes']),
            "Category": row['category'],
            "Family": row['family'],
            "License": row['license'],
            "Access": row['access'],
            "Params": row['param_size'],
            "Created": row['created']
        })

    return pd.DataFrame(df_data, columns=_RESULT_COLUMNS)
def create_ui():
    """Create the Gradio interface.

    Builds the filter controls (left column) and the results table (right
    column), re-runs ``process_models`` whenever any control changes, and
    loads the table once on page load. The "Refresh Data" button clears the
    model cache before recomputing.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    families, licenses = get_filter_options()
    family_choices = ["All"] + families
    license_choices = ["All"] + licenses
    access_choices = ["All", "Open", "Gated"]

    with gr.Blocks(theme=gr.themes.Soft(), title="Model Rank") as app:
        gr.Markdown("# Model Rank")
        with gr.Row():
            with gr.Column(scale=1):
                category = gr.Radio(
                    choices=["All", "Base Model", "Fine-Tune", "Quant"],
                    value="All",
                    label="Category Filter",
                    info="Filter models by category"
                )
                family = gr.Dropdown(
                    choices=family_choices,
                    value="All",
                    label="Model Family",
                    allow_custom_value=False
                )
                license_filter = gr.Dropdown(
                    choices=license_choices,
                    value="All",
                    label="License",
                    allow_custom_value=False
                )
                param_size_min = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=0.1,
                    label="Min Parameter Size (B)",
                    info="Minimum parameter size in billions (0 = no minimum)"
                )
                param_size_max = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=100,
                    step=0.1,
                    label="Max Parameter Size (B)",
                    info="Maximum parameter size in billions (100 = no maximum)"
                )
                access = gr.Radio(
                    choices=access_choices,
                    value="All",
                    label="Access",
                    info="Filter by whether the model is open or gated"
                )
                hidden_gems = gr.Checkbox(
                    value=False,
                    label="Show Hidden Gems only",
                    info="Models with reproducibility tags, high likes/downloads ratio, and <2K downloads"
                )
                active_only = gr.Checkbox(
                    value=False,
                    label="Only Active Models",
                    info="Restrict to models with downloads in the selected recent window"
                )
                activity_window = gr.Radio(
                    choices=["7d", "14d", "30d", "90d"],
                    value="30d",
                    label="Activity Window",
                    info="Period for measuring recent downloads"
                )
                timeframe = gr.Radio(
                    choices=["Last Day", "Last Week", "Last Month", "Last 3 Months", "All Time"],
                    value="All Time",
                    label="Timeframe",
                    info="Filter by when model was created"
                )
                sort_by = gr.Radio(
                    choices=["downloads", "likes"],
                    value="downloads",
                    label="Sort By",
                    info="Sort models by downloads or likes"
                )
                max_results = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Max Results",
                    info="Number of models to display"
                )
                refresh_btn = gr.Button("Refresh Data", variant="primary")
            with gr.Column(scale=3):
                output = gr.Dataframe(
                    headers=["Rank", "Model", "Downloads", "Likes", "Category", "Family", "License", "Access", "Params", "Created"],
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str", "str", "str"],
                    label="Models",
                    wrap=True,
                    interactive=False,
                    column_widths=["5%", "28%", "8%", "7%", "9%", "9%", "8%", "7%", "7%", "12%"]
                )

        # Event handlers. The positional order of these parameters must match
        # the order of `inputs` below.
        def update_table(cat, fam, lic, acc, gems, active, window, time, sort, max_res, param_min, param_max):
            return process_models(
                category_filter=cat,
                family_filter=fam,
                license_filter=lic,
                access_filter=acc,
                hidden_gems_only=gems,
                sort_by=sort,
                max_results=int(max_res),
                timeframe=time,
                active_only=active,
                activity_window=window,
                param_size_min=param_min,
                param_size_max=param_max,
            )

        def refresh_and_update(*values):
            # Clear cache to force refresh
            model_cache.clear()
            return update_table(*values)

        inputs = [category, family, license_filter, access, hidden_gems, active_only, activity_window, timeframe, sort_by, max_results, param_size_min, param_size_max]

        # Every control refreshes the table when it changes. This includes the
        # two parameter-size sliders, which previously had no handler and so
        # had no effect until some other control was touched.
        for control in inputs:
            control.change(fn=update_table, inputs=inputs, outputs=output)
        refresh_btn.click(fn=refresh_and_update, inputs=inputs, outputs=output)

        # Load initial data
        app.load(fn=update_table, inputs=inputs, outputs=output)
    return app
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    app = create_ui()
    app.launch()