import os

# Enable hf_transfer for faster Hub downloads. The flag is read when
# huggingface_hub is imported, so it needs to be set before the imports below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import platform
import re
from collections import defaultdict
from typing import Optional

import gradio as gr
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import get_collection, list_datasets, list_models

is_macos = platform.system() == "Darwin"

# Set LIMIT to an integer to cap how many repos are fetched from the Hub
# (useful for local testing); None fetches everything.
LIMIT = None
# How long cached Hub listings stay valid, in seconds.
CACHE_TIME = 60 * 5  # 5 minutes


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    return list(iter(list_models(full=True, limit=LIMIT)))


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    return list(iter(list_datasets(full=True, limit=LIMIT)))


get_models()  # warm up the cache
get_datasets()  # warm up the cache


def check_for_arxiv_id(model):
    # Return any tags that look like arXiv references (tags of the form "arxiv:<id>");
    # an empty list means the repo has no arXiv tags.
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else []


def extract_arxiv_id(input_string: str) -> Optional[str]:
    # Define the regular expression pattern for tags like "arxiv:<numeric id>"
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    # Search for the pattern in the input string
    match = pattern.search(input_string)
    # If a match is found, return the numeric part of the arXiv ID, else return None
    return match[1] if match else None
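
# For example, given the pattern above, extract_arxiv_id("arxiv:2106.09685")
# returns "2106.09685", and any string without an "arxiv:<numeric id>" tag
# returns None. (The ID here is just an illustrative placeholder.)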


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id
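
# Both helpers above map a repo id to the arXiv IDs found in its tags, e.g. roughly
# {"some-org/some-model": ["2301.00001"], ...} (the repo and paper IDs here are
# hypothetical placeholders).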
url = "lunarflu/ai-podcasts-and-talks-65119866353a60593bf99c58"


def group_collection_items(collection_slug: str):
    collection = get_collection(collection_slug)
    items = collection.items
    # Group collection items by repo type, giving keys such as "models" and "datasets".
    return groupby(lambda x: f"{x.repoType}s", items)


def get_papers_for_collection(collection_slug: str):
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets", None)
    collection_models = collection.get("models", None)
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                dataset_papers[dataset.item_id] = data
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := models_to_arxiv_id.get(model.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                model_papers[model.item_id] = data
    return {"datasets": dataset_papers, "models": model_papers}
url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"
gr.Interface(get_papers_for_collection, "text", "json").launch()