"""Gradio app that maps the models and datasets in a Hugging Face collection to their arXiv papers."""

import os
import platform
import re
from collections import defaultdict
from typing import Optional

import gradio as gr
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import get_collection, list_datasets, list_models

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

is_macos = platform.system() == "Darwin"

LIMIT = None  # maximum number of repos to fetch from the Hub; None fetches everything
CACHE_TIME = 60 * 5  # cache Hub listings for five minutes


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    """Return metadata for all models on the Hub (cached for CACHE_TIME seconds)."""
    return list(list_models(full=True, limit=LIMIT))


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    """Return metadata for all datasets on the Hub (cached for CACHE_TIME seconds)."""
    return list(list_datasets(full=True, limit=LIMIT))


# Warm the caches at startup so the first request doesn't pay the full listing cost.
get_models()
get_datasets()


def check_for_arxiv_id(model):
    """Return any tags on the repo that reference an arXiv paper, or False if the repo has no tags."""
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else False


def extract_arxiv_id(input_string: str) -> Optional[str]:
    """Extract the arXiv identifier from a Hub tag of the form "arxiv:<id>"."""
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    return match[1] if match else None
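# Illustrative only: given a tag in the "arxiv:<id>" format used on the Hub,
# extract_arxiv_id returns the bare identifier, e.g.
#   extract_arxiv_id("arxiv:2106.09685")  # -> "2106.09685"
#   extract_arxiv_id("not-an-arxiv-tag")  # -> None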


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    """Build a mapping from model ID to the arXiv IDs referenced in its tags."""
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id
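# Sketch of the resulting shape (repo names and arXiv IDs here are hypothetical):
# {
#     "some-org/some-model": ["2106.09685"],
#     "another-org/another-model": ["1910.01108", "2005.14165"],
# }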


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    """Build a mapping from dataset ID to the arXiv IDs referenced in its tags."""
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id


# Example collection slug (not used programmatically; handy for trying the app out).
url = "lunarflu/ai-podcasts-and-talks-65119866353a60593bf99c58"


def group_collection_items(collection_slug: str):
    """Fetch a collection and group its items by repo type ("models", "datasets", ...)."""
    collection = get_collection(collection_slug)
    items = collection.items
    return groupby(lambda x: f"{x.repoType}s", items)
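# Rough shape of the grouped result, assuming the collection contains models and
# datasets (keys follow the item's repo type, pluralized; slugs are hypothetical):
# {
#     "models": [<item with item_id "org/model-a">, ...],
#     "datasets": [<item with item_id "org/dataset-a">, ...],
# }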


def get_papers_for_collection(collection_slug: str):
    """Return the arXiv IDs and Hub paper links for every model and dataset in a collection."""
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets")
    collection_models = collection.get("models")
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id):
                dataset_papers[dataset.item_id] = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := models_to_arxiv_id.get(model.item_id):
                model_papers[model.item_id] = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
    return {"datasets": dataset_papers, "models": model_papers}
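# Illustrative output for a collection containing one dataset and one model that
# carry arXiv tags (repo names and arXiv IDs are hypothetical):
# {
#     "datasets": {
#         "org/dataset-a": {
#             "arxiv_ids": ["2106.09685"],
#             "hub_paper_links": ["https://huggingface.co/papers/2106.09685"],
#         }
#     },
#     "models": {
#         "org/model-a": {
#             "arxiv_ids": ["1910.01108"],
#             "hub_paper_links": ["https://huggingface.co/papers/1910.01108"],
#         }
#     },
# }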


# Another example collection slug that can be tried in the app.
url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"

# Expose the lookup as a simple Gradio app: paste a collection slug, get back JSON.
gr.Interface(get_papers_for_collection, "text", "json").launch()