"""Gradio app that returns the arXiv IDs and Hugging Face Paper links for the
models, datasets, and papers in a Hugging Face Collection."""

import os
import platform
import re
from collections import defaultdict

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import CollectionItem, get_collection, list_datasets, list_models
from tqdm.auto import tqdm

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# When running locally (macOS), only fetch a subset of the Hub listings.
is_macos = platform.system() == "Darwin"
LIMIT = 1000 if is_macos else None
CACHE_TIME = 60 * 15  # cache results for 15 minutes


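# Fetching the full model and dataset listings from the Hub is slow, so the two
# helpers below memoize the results in a 15-minute TTL cache.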
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    print("getting models...")
    return list(tqdm(iter(list_models(full=True, limit=LIMIT))))


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    print("getting datasets...")
    return list(tqdm(iter(list_datasets(full=True, limit=LIMIT))))


# Populate the caches at startup.
get_models()
get_datasets()


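# Hub repos link papers through tags of the form "arxiv:<id>"; the helpers below
# find those tags and extract the bare arXiv identifier.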
def check_for_arxiv_id(model):
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else False


def extract_arxiv_id(input_string: str) -> str | None:
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    return match[1] if match else None


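# Build {repo_id: [arxiv_id, ...]} lookup tables covering every model and dataset
# on the Hub; both are cached because they iterate over the full listings.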
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id


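# Collection items expose an item_type ("model", "dataset", "paper", ...); group
# the items by type so each group can be handled separately.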
def get_collection_type(collection_item: CollectionItem):
    try:
        return f"{collection_item.item_type}s"
    except AttributeError:
        return None


def group_collection_items(collection_slug: str):
    collection = get_collection(collection_slug)
    items = collection.items
    return groupby(get_collection_type, items)


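# Map every item in a collection to its linked papers. The returned dict has the
# shape (illustrative):
#   {
#       "dataset papers": {<dataset_id>: {"arxiv_ids": [...], "hub_paper_links": [...]}},
#       "model papers": {<model_id>: {...}},
#       "papers": {<arxiv_id>: {...}},
#   }
# with None for any group that has no linked papers.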
@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def get_papers_for_collection(collection_slug: str):
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets", None)
    collection_models = collection.get("models", None)
    papers = collection.get("papers", None)
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    collection_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                dataset_papers[dataset.item_id] = data
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := models_to_arxiv_id.get(model.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                model_papers[model.item_id] = data
    if papers is not None:
        for paper in papers:
            data = {
                "arxiv_ids": [paper.item_id],
                "hub_paper_links": [f"https://huggingface.co/papers/{paper.item_id}"],
            }
            collection_papers[paper.item_id] = data
    if not dataset_papers:
        dataset_papers = None
    if not model_papers:
        model_papers = None
    if not collection_papers:
        collection_papers = None
    return {
        "dataset papers": dataset_papers,
        "model papers": model_papers,
        "papers": collection_papers,
    }


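# Refresh the cached Hub listings in the background every 15 minutes so results
# stay reasonably fresh without blocking user requests.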
scheduler = BackgroundScheduler()
scheduler.add_job(get_datasets, "interval", minutes=15)
scheduler.add_job(get_models, "interval", minutes=15)
scheduler.start()


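# Gradio UI: a single textbox for the Collection slug, with the result shown as JSON.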
placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"
slug_input = gr.Textbox(
    placeholder=placeholder_url, interactive=True, label="Collection slug", max_lines=1
)
description = (
    "Enter a Collection slug to get the arXiv IDs and Hugging Face Paper links for"
    " papers associated with models and datasets in the collection. If the collection"
    " includes papers, the arXiv IDs and Hugging Face Paper links will be returned for"
    " those papers as well."
)

examples = [
    placeholder_url,
    "davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]

gr.Interface(
    get_papers_for_collection,
    slug_input,
    "json",
    title="ππ: Extract linked papers from a Hugging Face Collection",
    description=description,
    examples=examples,
    cache_examples=True,
).queue(concurrency_count=4).launch()
|