import os
import platform
import re
from collections import defaultdict
import gradio as gr
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import CollectionItem, get_collection, list_datasets, list_models
from tqdm.auto import tqdm
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
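
# Enable hf_transfer for faster file transfers from the Hugging Face Hub.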
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
is_macos = platform.system() == "Darwin"
# Limit Hub listings during local development, where fetching the full listing is slow.
LIMIT = 1000 if is_macos else None
CACHE_TIME = 60 * 15 # 15 minutes
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
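    """Return all models listed on the Hub (cached for CACHE_TIME seconds)."""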
print("getting models...")
return list(tqdm(iter(list_models(full=True, limit=LIMIT))))
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
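    """Return all datasets listed on the Hub (cached for CACHE_TIME seconds)."""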
print("getting datasets...")
return list(tqdm(iter(list_datasets(full=True, limit=LIMIT))))
get_models() # warm up the cache
get_datasets() # warm up the cache
def check_for_arxiv_id(model):
    # Return any tags that reference an arXiv paper, e.g. "arxiv:2106.09685".
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else []
def extract_arxiv_id(input_string: str) -> str | None:
    # Pull the numeric arXiv ID out of a tag like "arxiv:2106.09685".
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    return match[1] if match else None
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
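    """Map each model ID to the arXiv IDs referenced in its tags."""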
models = get_models()
model_to_arxiv_id = {}
for model in models:
if arxiv_papers := check_for_arxiv_id(model):
clean_arxiv_ids = []
for paper in arxiv_papers:
if arxiv_id := extract_arxiv_id(paper):
clean_arxiv_ids.append(arxiv_id)
model_to_arxiv_id[model.modelId] = clean_arxiv_ids
return model_to_arxiv_id
@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
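    """Map each dataset ID to the arXiv IDs referenced in its tags."""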
datasets = get_datasets()
dataset_to_arxiv_id = {}
for dataset in datasets:
if arxiv_papers := check_for_arxiv_id(dataset):
clean_arxiv_ids = []
for paper in arxiv_papers:
if arxiv_id := extract_arxiv_id(paper):
clean_arxiv_ids.append(arxiv_id)
dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
return dataset_to_arxiv_id
def get_collection_type(collection_item: CollectionItem):
    # Pluralize the item type ("model" -> "models") to use as a grouping key.
    try:
        return f"{collection_item.item_type}s"
    except AttributeError:
        return None
def group_collection_items(collection_slug: str):
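    """Fetch a collection from the Hub and group its items by type."""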
collection = get_collection(collection_slug)
items = collection.items
return groupby(get_collection_type, items)
@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def get_papers_for_collection(collection_slug: str):
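    """Collect arXiv IDs and hf.co/papers links for a collection's models, datasets, and papers."""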
dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
models_to_arxiv_id = create_model_to_arxiv_id_dict()
collection = group_collection_items(collection_slug)
collection_datasets = collection.get("datasets", None)
collection_models = collection.get("models", None)
papers = collection.get("papers", None)
dataset_papers = defaultdict(dict)
model_papers = defaultdict(dict)
collection_papers = defaultdict(dict)
if collection_datasets is not None:
for dataset in collection_datasets:
if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id, None):
data = {
"arxiv_ids": arxiv_ids,
"hub_paper_links": [
f"https://huggingface.co/papers/{arxiv_id}"
for arxiv_id in arxiv_ids
],
}
dataset_papers[dataset.item_id] = data
    if collection_models is not None:
        for model in collection_models:
if arxiv_ids := models_to_arxiv_id.get(model.item_id, None):
data = {
"arxiv_ids": arxiv_ids,
"hub_paper_links": [
f"https://huggingface.co/papers/{arxiv_id}"
for arxiv_id in arxiv_ids
],
}
model_papers[model.item_id] = data
if papers is not None:
for paper in papers:
data = {
"arxiv_ids": [paper.item_id],
"hub_paper_links": [f"https://huggingface.co/papers/{paper.item_id}"],
}
collection_papers[paper.item_id] = data
if not dataset_papers:
dataset_papers = None
if not model_papers:
model_papers = None
if not collection_papers:
collection_papers = None
return {
"dataset papers": dataset_papers,
"model papers": model_papers,
"papers": collection_papers,
}
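
# Re-run the listing functions periodically so the TTL caches are refreshed
# in the background rather than on a user's request.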
scheduler = BackgroundScheduler()
scheduler.add_job(get_datasets, "interval", minutes=15)
scheduler.add_job(get_models, "interval", minutes=15)
scheduler.start()
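
# Example collection slug shown as the input placeholder.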
placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"
slug_input = gr.Textbox(
placeholder=placeholder_url, interactive=True, label="Collection slug", max_lines=1
)
description = (
    "Enter a Collection slug to get the arXiv IDs and Hugging Face Paper links for"
    " papers associated with models and datasets in the collection. If the collection"
    " includes papers, the arXiv IDs and Hugging Face Paper links for those papers"
    " are returned as well."
)
examples = [
placeholder_url,
"davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]
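
# Build the Gradio interface: a collection slug in, a JSON mapping of papers out.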
gr.Interface(
get_papers_for_collection,
slug_input,
"json",
title="ππ: Extract linked papers from a Hugging Face Collection",
description=description,
examples=examples,
cache_examples=True,
).queue(concurrency_count=4).launch()