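"""Gradio Space that surfaces new arXiv papers likely to introduce new datasets.

Recent arXiv results are fetched, scored with a SetFit classifier, and
rendered as a Markdown feed that can be filtered by broad arXiv category.
The cached data is rebuilt daily by a background scheduler.
"""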
import arxiv
import gradio as gr
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
from cachetools import TTLCache, cached
from setfit import SetFitModel
from tqdm.auto import tqdm
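
# Cache lifetime in seconds (12 hours) and the maximum number of arXiv results to fetch.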
CACHE_TIME = 60 * 60 * 12
MAX_RESULTS = 30_000


@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def get_arxiv_result():
    """Query arXiv for recent papers matching the dataset/machine-learning search."""
    search = arxiv.Search(
        query="ti:dataset AND abs:machine learning",
        max_results=MAX_RESULTS,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    return [
        {
            "title": result.title,
            "abstract": result.summary,
            "url": result.entry_id,
            "category": result.primary_category,
            "updated": result.updated,
        }
        for result in tqdm(search.results(), total=MAX_RESULTS)
    ]


def load_model():
    """Load the SetFit model used to classify new-dataset papers."""
    return SetFitModel.from_pretrained("librarian-bots/is_new_dataset_teacher_model")


def format_row_for_model(row):
    """Format a paper's title and abstract into the text input the model expects."""
    return f"TITLE: {row['title']} \n\nABSTRACT: {row['abstract']}"
int2label = {0: "new_dataset", 1: "not_new_dataset"}


def get_predictions(data: list[dict], model=None, batch_size=32):
    """Attach a predicted label and its probability to each paper dict."""
    if model is None:
        model = load_model()
    predictions = []
    for i in tqdm(range(0, len(data), batch_size)):
        batch = data[i : i + batch_size]
        text_inputs = [format_row_for_model(row) for row in batch]
        batch_predictions = model.predict_proba(text_inputs)
        for j, row in enumerate(batch):
            prediction = batch_predictions[j]
            row["prediction"] = int2label[int(prediction.argmax())]
            row["probability"] = float(prediction.max())
            predictions.append(row)
    return predictions


def create_markdown(row):
    """Render a single paper as Markdown with a link to its Hugging Face Papers page."""
    title = row["title"]
    abstract = row["abstract"]
    arxiv_id = row["arxiv_id"]
    hub_paper_url = f"https://huggingface.co/papers/{arxiv_id}"
    updated = row["updated"].strftime("%Y-%m-%d")
    broad_category = row["broad_category"]
    category = row["category"]
    return f""" <h1> {title} </h1> updated: {updated}
| category: {broad_category} | subcategory: {category} |
\n\n{abstract}
\n\n [Hugging Face Papers page]({hub_paper_url})
"""


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def prepare_data():
    """Fetch arXiv results, classify them, and build a DataFrame with rendered Markdown."""
    print("Downloading arXiv results...")
    arxiv_results = get_arxiv_result()
    print("Loading model...")
    model = load_model()
    print("Making predictions...")
    predictions = get_predictions(arxiv_results, model=model)
    df = pd.DataFrame(predictions)
    # Pull the arXiv ID (e.g. "2301.01234") out of the entry URL.
    df.loc[:, "arxiv_id"] = df["url"].str.extract(r"(\d+\.\d+)")
    # Reduce a full category such as "cs.CL" to its broad prefix "cs".
    df.loc[:, "broad_category"] = df["category"].str.split(".").str[0]
    df.loc[:, "markdown"] = df.apply(create_markdown, axis=1)
    return df
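

# prepare_data() is cached, so these back-to-back calls reuse a single result.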
all_possible_arxiv_categories = prepare_data().category.unique().tolist()
broad_categories = prepare_data().broad_category.unique().tolist()


def create_markdown_summary(categories=broad_categories):
    """Build the Markdown feed, optionally filtered to the selected broad categories."""
    df = prepare_data()
    if categories is not None:
        df = df[df["broad_category"].isin(categories)]
    return "\n\n".join(df["markdown"].tolist())
scheduler = BackgroundScheduler()
scheduler.add_job(prepare_data, "cron", hour=3, minute=30)
scheduler.start()


with gr.Blocks() as demo:
    gr.Markdown("## New Datasets in Machine Learning")
    gr.Markdown(
        "This Space attempts to show new papers on arXiv that are *likely* to be papers"
        " introducing new datasets. \n\n"
    )
    # Let users filter the feed by broad arXiv category.
    broad_categories_dropdown = gr.Dropdown(
        choices=broad_categories,
        label="Categories",
        multiselect=True,
        value=broad_categories,
    )
    results = gr.Markdown(create_markdown_summary())
    # Re-render the feed whenever the category selection changes.
    broad_categories_dropdown.change(
        create_markdown_summary, broad_categories_dropdown, results
    )

demo.launch()