File size: 5,839 Bytes
0750ab5
310dfa4
 
0750ab5
 
310dfa4
0750ab5
310dfa4
0f6c693
fd8aa7c
212995a
 
0750ab5
9ce3fb9
310dfa4
212995a
479b4a3
212995a
0750ab5
310dfa4
206a7be
310dfa4
212995a
fd8aa7c
310dfa4
 
206a7be
310dfa4
212995a
fd8aa7c
310dfa4
9ce3fb9
 
 
 
310dfa4
 
 
 
 
 
 
 
 
 
 
206a7be
310dfa4
 
 
 
 
 
 
 
 
 
 
 
 
206a7be
310dfa4
 
 
 
 
 
 
 
 
 
 
 
 
479b4a3
 
 
 
 
310dfa4
 
 
 
 
479b4a3
310dfa4
212995a
adf5d9a
310dfa4
 
 
 
 
 
479b4a3
1eb4ea3
 
479b4a3
310dfa4
 
1eb4ea3
310dfa4
1eb4ea3
 
 
 
 
310dfa4
1eb4ea3
310dfa4
 
9ce3fb9
310dfa4
9ce3fb9
 
 
 
 
310dfa4
1eb4ea3
479b4a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310dfa4
 
212995a
 
 
 
 
c1af0d8
 
 
 
a0915df
0f6c693
479b4a3
0f6c693
479b4a3
a0915df
310dfa4
a0915df
 
 
 
212995a
 
a0915df
 
 
 
479b4a3
a0915df
 
 
adf5d9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import platform
import re
from collections import defaultdict
from typing import Optional

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import CollectionItem, get_collection, list_datasets, list_models
from tqdm.auto import tqdm

# Opt in to the hf_transfer backend of huggingface_hub.
# NOTE(review): this is set after huggingface_hub has already been imported;
# confirm the library reads the variable lazily, otherwise it has no effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# A macOS host is treated as the local development machine.
is_macos = local = platform.system() == "Darwin"
# Cap Hub listings on local dev (slow internet); no cap anywhere else.
LIMIT = None if not is_macos else 1000
# Lifetime of every TTL cache in this module, in seconds (15 minutes).
CACHE_TIME = 15 * 60


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    """Return the Hub model listing as a list, showing a progress bar.

    The listing is requested with ``full=True`` and capped at ``LIMIT``
    entries. The result is memoised in a TTL cache for ``CACHE_TIME`` seconds.
    """
    print("getting models...")
    model_iterator = iter(list_models(full=True, limit=LIMIT))
    progress = tqdm(model_iterator)
    return list(progress)


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    """Return the Hub dataset listing as a list, showing a progress bar.

    The listing is requested with ``full=True`` and capped at ``LIMIT``
    entries. The result is memoised in a TTL cache for ``CACHE_TIME`` seconds.
    """
    print("getting datasets...")
    dataset_iterator = iter(list_datasets(full=True, limit=LIMIT))
    progress = tqdm(dataset_iterator)
    return list(progress)


# Fetch both listings once at import time so the TTL caches are already
# populated before the first user request arrives.
get_models()  # warm up the cache
get_datasets()  # warm up the cache


def check_for_arxiv_id(model):
    """Return the tags of *model* that mention "arxiv".

    Returns ``False`` when the object has no tags (``None`` or empty);
    otherwise a list, possibly empty, of every tag containing the substring
    ``"arxiv"``.
    """
    tags = model.tags
    if not tags:
        return False
    return [tag for tag in tags if "arxiv" in tag]


# Matches "arxiv:<digits>.<digits>" and captures the numeric identifier.
# Compiled once at import time instead of on every call.
_ARXIV_ID_PATTERN = re.compile(r"\barxiv:(\d+\.\d+)\b")


def extract_arxiv_id(input_string: str) -> Optional[str]:
    """Extract the arXiv identifier from a tag such as ``"arxiv:1910.09700"``.

    Args:
        input_string: Text to search, typically a single Hub tag.

    Returns:
        The first ``<digits>.<digits>`` identifier that follows ``arxiv:``,
        or ``None`` when the text contains no such identifier.
    """
    match = _ARXIV_ID_PATTERN.search(input_string)
    return match[1] if match else None


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    """Map each model's ``modelId`` to the arXiv IDs found in its tags.

    Only models with at least one tag mentioning "arxiv" get an entry. The
    entry's list may be empty when none of those tags yields a parsable ID.
    The mapping is memoised in a TTL cache for ``CACHE_TIME`` seconds.
    """
    mapping = {}
    for hub_model in get_models():
        arxiv_tags = check_for_arxiv_id(hub_model)
        if not arxiv_tags:
            continue
        candidates = (extract_arxiv_id(tag) for tag in arxiv_tags)
        # Drop tags that mention arxiv but carry no parsable identifier.
        mapping[hub_model.modelId] = [paper_id for paper_id in candidates if paper_id]
    return mapping


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    """Map each dataset's ``id`` to the arXiv IDs found in its tags.

    Only datasets with at least one tag mentioning "arxiv" get an entry. The
    entry's list may be empty when none of those tags yields a parsable ID.
    The mapping is memoised in a TTL cache for ``CACHE_TIME`` seconds.
    """
    mapping = {}
    for hub_dataset in get_datasets():
        arxiv_tags = check_for_arxiv_id(hub_dataset)
        if not arxiv_tags:
            continue
        candidates = (extract_arxiv_id(tag) for tag in arxiv_tags)
        # Drop tags that mention arxiv but carry no parsable identifier.
        mapping[hub_dataset.id] = [paper_id for paper_id in candidates if paper_id]
    return mapping


def get_collection_type(collection_item: CollectionItem):
    """Return the pluralised item type of a collection item.

    The result is ``item_type`` with an ``"s"`` appended, or ``None`` when
    the object has no ``item_type`` attribute.
    """
    try:
        item_type = collection_item.item_type
    except AttributeError:
        return None
    return f"{item_type}s"


def group_collection_items(collection_slug: str):
    """Fetch the collection for *collection_slug* and group its items.

    Items are grouped with ``get_collection_type`` as the key function, so
    the keys are the pluralised item types.
    """
    return groupby(get_collection_type, get_collection(collection_slug).items)


def _hub_paper_link(arxiv_id):
    """Build the Hugging Face Papers URL for a single arXiv ID."""
    return f"https://huggingface.co/papers/{arxiv_id}"


def _papers_for_items(items, item_to_arxiv_ids):
    """Collect paper info for the collection items that have known arXiv IDs.

    Args:
        items: Collection items exposing ``item_id``, or ``None`` when the
            collection has no items of this kind.
        item_to_arxiv_ids: Mapping from repo ID to its list of arXiv IDs.

    Returns:
        A dict keyed by ``item_id``. Items with no (or an empty) arXiv ID
        list are skipped.
    """
    papers = {}
    for item in items or ():
        if arxiv_ids := item_to_arxiv_ids.get(item.item_id):
            papers[item.item_id] = {
                "arxiv_ids": arxiv_ids,
                "hub_paper_links": [_hub_paper_link(arxiv_id) for arxiv_id in arxiv_ids],
            }
    return papers


@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def get_papers_for_collection(collection_slug: str):
    """Return the papers linked from a Hugging Face Collection.

    Args:
        collection_slug: Slug of the collection to inspect.

    Returns:
        A dict with the keys ``"dataset papers"``, ``"model papers"`` and
        ``"papers"``. Each value is ``None`` when nothing was found,
        otherwise a dict keyed by item ID whose entries hold ``"arxiv_ids"``
        and ``"hub_paper_links"``. Results are memoised per slug for
        ``CACHE_TIME`` seconds.
    """
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)

    dataset_papers = _papers_for_items(collection.get("datasets"), dataset_to_arxiv_id)
    model_papers = _papers_for_items(collection.get("models"), models_to_arxiv_id)
    # For paper items the item ID is itself used as the arXiv ID. "arxiv_ids"
    # stays a bare string here (not a list) to keep the existing output shape.
    collection_papers = {
        paper.item_id: {
            "arxiv_ids": paper.item_id,
            "hub_paper_links": [_hub_paper_link(paper.item_id)],
        }
        for paper in collection.get("papers") or ()
    }

    # Plain dicts are returned on purpose: the result is cached, and a
    # defaultdict would silently grow if a caller read a missing key.
    return {
        "dataset papers": dataset_papers or None,
        "model papers": model_papers or None,
        "papers": collection_papers or None,
    }


# Call the cached listing functions from a background thread every 15 minutes.
scheduler = BackgroundScheduler()
for _refresh_job in (get_datasets, get_models):
    scheduler.add_job(_refresh_job, "interval", minutes=15)
scheduler.start()

# Example collection slug, used both as the textbox placeholder and as the
# first clickable example.
placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"
slug_input = gr.Textbox(
    placeholder=placeholder_url, interactive=True, label="Collection slug", max_lines=1
)
description = (
    "Enter a Collection slug to get the arXiv IDs and Hugging Face Paper links for"
    " papers associated with models and datasets in the collection. If the collection"
    " includes papers the arXiv IDs and Hugging Face Paper links will be returned for"
    " those papers as well."
)

examples = [
    placeholder_url,
    "davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]


# Build the UI (slug textbox in, JSON out) and start serving.
# The title's emoji prefix was mojibake (mis-decoded UTF-8) and is restored.
# NOTE(review): `concurrency_count` is Gradio-version dependent; confirm it
# against the pinned Gradio version.
gr.Interface(
    get_papers_for_collection,
    slug_input,
    "json",
    title="📄🔗: Extract linked papers from a Hugging Face Collection",
    description=description,
    examples=examples,
    cache_examples=True,
).queue(concurrency_count=4).launch()