# Gradio app: list the papers (arXiv IDs and Hub paper links) associated with the
# models and datasets in a Hugging Face Hub collection.
import os
import platform
import re
from collections import defaultdict

import gradio as gr
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import get_collection, list_datasets, list_models
from tqdm.auto import tqdm

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

is_macos = platform.system() == "Darwin"
LIMIT = None  # set to an int to cap how many repos are fetched (handy for testing)
CACHE_TIME = 60 * 5  # 5 minutes


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    # Fetch (and cache) metadata for all public models on the Hub.
    return list(tqdm(iter(list_models(full=True, limit=LIMIT))))


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    # Fetch (and cache) metadata for all public datasets on the Hub.
    return list(tqdm(iter(list_datasets(full=True, limit=LIMIT))))


get_models()  # warm up the cache
get_datasets()  # warm up the cache


def check_for_arxiv_id(model):
    # Return the repo's arxiv:* tags, or False if it has none.
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else False


def extract_arxiv_id(input_string: str) -> str:
    # Match tags of the form "arxiv:<id>", e.g. "arxiv:2304.01234".
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    # If a match is found, return the numeric part of the arXiv ID, else None.
    return match[1] if match else None


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    # Map each model ID to the arXiv IDs found in its tags.
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    # Map each dataset ID to the arXiv IDs found in its tags.
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id


placeholder_url = "lunarflu/ai-podcasts-and-talks-65119866353a60593bf99c58"


def group_collection_items(collection_slug: str):
    # Group a collection's items by repo type, keyed as "models", "datasets", etc.
    collection = get_collection(collection_slug)
    items = collection.items
    return groupby(lambda x: f"{x.item_type}s", items)


def get_papers_for_collection(collection_slug: str):
    # Build arXiv ID / Hub paper link info for every model and dataset in a collection.
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets", None)
    collection_models = collection.get("models", None)
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                dataset_papers[dataset.item_id] = data
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := models_to_arxiv_id.get(model.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                model_papers[model.item_id] = data
    return {"dataset papers": dataset_papers, "model papers": model_papers}


placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"

slug_input = gr.Textbox(
    placeholder=placeholder_url,
    interactive=True,
    label="Collection slug",
    max_lines=1,
)

description = (
    "Enter a collection slug to get a list of papers associated with models and"
    " datasets in the collection."
)

examples = [
    placeholder_url,
    "davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]

gr.Interface(
    get_papers_for_collection,
    slug_input,
    "json",
    description=description,
    examples=examples,
    cache_examples=True,
).launch()
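
# For reference, a hedged sketch of the JSON shape returned by
# get_papers_for_collection. The keys come from the code above; the repo IDs and
# arXiv IDs below are illustrative placeholders, not real output:
#
#   {
#       "dataset papers": {
#           "example-org/example-dataset": {
#               "arxiv_ids": ["2304.01234"],
#               "hub_paper_links": ["https://huggingface.co/papers/2304.01234"],
#           },
#       },
#       "model papers": {
#           "example-org/example-model": {
#               "arxiv_ids": ["2310.05678"],
#               "hub_paper_links": ["https://huggingface.co/papers/2310.05678"],
#           },
#       },
#   }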