# Hugging Face Space status banner (page-scrape residue): Running on CPU Upgrade
import dataclasses
import datetime
import html
import operator

import datasets
import pandas as pd
import tqdm.auto
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel
# Local directory that the prebuilt abstract index is downloaded into and then
# loaded from by RAGPretrainedModel.from_index below.
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/daily-papers-abstract-index/"

api = HfApi()
# Mirror the Hub dataset repo holding the index into INDEX_DIR_PATH.
api.snapshot_download(
    repo_type="dataset",
    repo_id="hysts-bot-data/daily-papers-abstract-index",
    local_dir=INDEX_DIR_PATH,
)

# Retriever over paper abstracts; PaperList.search queries it for abstract searches.
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Throwaway warm-up query at import time, so the retriever is already
# initialized when the first real search arrives.
ABSTRACT_RETRIEVER.search("LLM")
@dataclasses.dataclass
class PaperInfo:
    """One row of the paper table, built per merged dataset row by ``get_df``.

    The ``@dataclasses.dataclass`` decorator is required: ``get_df`` constructs
    instances with keyword arguments and serializes them with
    ``dataclasses.asdict``. Annotations are not enforced at runtime; values are
    stored as they come from the source datasets.
    """

    date: str
    arxiv_id: str
    github: str  # may be empty when a paper has no code link (presumably; verify upstream)
    title: str
    paper_page: str  # derived in get_df as https://huggingface.co/papers/<arxiv_id>
    upvotes: int
    num_comments: int
def get_df() -> pd.DataFrame:
    """Load the daily-papers listing joined with its stats as a flat DataFrame.

    ``hysts-bot-data/daily-papers`` is merged with
    ``hysts-bot-data/daily-papers-stats`` on ``arxiv_id``, using the
    ``pd.merge`` default (inner join). The row order is then reversed, and
    each row is normalized through ``PaperInfo``, which adds a ``paper_page``
    URL.

    Returns:
        DataFrame whose columns are exactly the fields of ``PaperInfo`` in
        declaration order. The columns are present even when there are no
        rows.

    Raises:
        KeyError: if the merged data lacks a ``PaperInfo`` field other than
            ``paper_page``.
    """
    df = pd.merge(
        left=datasets.load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas(),
        right=datasets.load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas(),
        on="arxiv_id",
    )
    # Reverse the row order (presumably newest-first for display; confirm
    # against the upstream dataset ordering).
    df = df[::-1].reset_index(drop=True)

    all_fields = [field.name for field in dataclasses.fields(PaperInfo)]
    # Forward only the columns PaperInfo declares. A blanket ``**row`` raises
    # TypeError as soon as either upstream dataset gains an extra column.
    # ``paper_page`` is derived here, so it is never taken from the row.
    source_fields = [name for name in all_fields if name != "paper_page"]

    paper_info = [
        PaperInfo(
            **{name: row[name] for name in source_fields},
            paper_page=f"https://huggingface.co/papers/{row.arxiv_id}",
        )
        for _, row in tqdm.auto.tqdm(df.iterrows(), total=len(df))
    ]
    # Passing ``columns`` keeps the schema stable even for an empty result.
    return pd.DataFrame(
        [dataclasses.asdict(info) for info in paper_info],
        columns=all_fields,
    )
class Prettifier: | |
def get_github_link(link: str) -> str: | |
if not link: | |
return "" | |
return Prettifier.create_link("github", link) | |
def create_link(text: str, url: str) -> str: | |
return f'<a href="{url}" target="_blank">{text}</a>' | |
def to_div(text: str | None, category_name: str) -> str: | |
if text is None: | |
text = "" | |
class_name = f"{category_name}-{text.lower()}" | |
return f'<div class="{class_name}">{text}</div>' | |
def __call__(self, df: pd.DataFrame) -> pd.DataFrame: | |
new_rows = [] | |
for _, row in df.iterrows(): | |
new_row = { | |
"date": Prettifier.create_link(row.date, f"https://huggingface.co/papers?date={row.date}"), | |
"paper_page": Prettifier.create_link(row.arxiv_id, row.paper_page), | |
"title": row["title"], | |
"github": self.get_github_link(row.github), | |
"π": row["upvotes"], | |
"π¬": row["num_comments"], | |
} | |
new_rows.append(new_row) | |
return pd.DataFrame(new_rows) | |
class PaperList:
    """Hold the raw paper table and answer filtered, display-ready queries over it."""

    # [column name, datatype] pairs in display order. The names must match the
    # keys produced by Prettifier.__call__.
    COLUMN_INFO = [
        ["date", "markdown"],
        ["paper_page", "markdown"],
        ["title", "str"],
        ["github", "markdown"],
        ["π", "number"],
        ["π¬", "number"],
    ]

    def __init__(self, df: pd.DataFrame):
        """Keep *df* as the raw table and precompute its prettified form."""
        self.df_raw = df
        self._prettifier = Prettifier()
        self.df_prettified = self._prettify(df)

    @property
    def column_names(self) -> list[str]:
        """Display column names, in COLUMN_INFO order."""
        return [name for name, _ in self.COLUMN_INFO]

    @property
    def column_datatype(self) -> list[str]:
        """Datatype strings parallel to ``column_names``."""
        return [dtype for _, dtype in self.COLUMN_INFO]

    def _prettify(self, df: pd.DataFrame) -> pd.DataFrame:
        """Prettify *df* and keep only the display columns; tolerate empty input."""
        if df.empty:
            # Selecting columns on a prettified empty frame can raise KeyError
            # (it may carry no columns at all), so build the empty result directly.
            return pd.DataFrame(columns=self.column_names)
        return self._prettifier(df).loc[:, self.column_names]

    def search(
        self,
        start_date: datetime.datetime,
        end_date: datetime.datetime,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
    ) -> pd.DataFrame:
        """Filter papers by date range and title, and optionally by abstract retrieval.

        Args:
            start_date: inclusive lower bound for the parsed ``date`` column.
            end_date: inclusive upper bound for the parsed ``date`` column.
            title_search_query: case-insensitive literal substring matched
                against titles. "" matches every title; rows with a missing
                title never match.
            abstract_search_query: if non-empty, sent to ABSTRACT_RETRIEVER.
                Only papers it returns that also passed the other filters are
                kept, in retriever order.
            max_num_to_retrieve: ``k`` passed to the retriever.

        Returns:
            Prettified frame restricted to ``column_names``. When nothing
            matches, the frame is empty but still has those columns.
        """
        df = self.df_raw.copy()
        df["date"] = pd.to_datetime(df["date"])

        # Filter by date (inclusive on both ends). assign() builds a new frame
        # instead of writing into a filtered slice (no SettingWithCopyWarning).
        in_range = (df["date"] >= start_date) & (df["date"] <= end_date)
        df = df[in_range]
        df = df.assign(date=df["date"].dt.strftime("%Y-%m-%d"))

        # Filter by title. regex=False makes queries such as "C++" or "(" match
        # literally instead of being compiled as a (possibly invalid) regex.
        # na=False keeps the mask boolean when a title is missing.
        title_mask = df["title"].str.contains(title_search_query, case=False, regex=False, na=False)
        df = df[title_mask]

        # Filter by abstract
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["arxiv_id"])
            # Keep retriever order. Drop hits outside the current filters and
            # repeated document ids; dict.fromkeys dedupes while keeping
            # first-seen order.
            found_ids = list(
                dict.fromkeys(hit["document_id"] for hit in results if hit["document_id"] in remaining_ids)
            )
            df = df[df["arxiv_id"].isin(found_ids)].set_index("arxiv_id").reindex(index=found_ids).reset_index()

        return self._prettify(df)