# Source: paper-central / df / PaperCentral.py
# Last commit dbdfc66 ("chat paper") by jbdel — 22.1 kB
import pandas as pd
from typing import List, Dict, Optional
from constants import (
DATASET_ARXIV_SCAN_PAPERS,
DATASET_CONFERENCE_PAPERS,
DATASET_COMMUNITY_SCIENCE,
NEURIPS_ICO,
DATASET_PAPER_CENTRAL,
COLM_ICO,
DEFAULT_ICO,
MICCAI24ICO,
CORL_ICO,
)
import gradio as gr
from utils import load_and_process
import numpy as np
from datetime import datetime, timedelta
import re
class PaperCentral:
    """
    A class to manage and process paper data for display in a Gradio Dataframe component.

    ``__init__`` loads the Paper Central dataset once; ``filter`` derives a filtered,
    prettified (markdown/HTML links) and renamed view of it for the UI.
    """

    CONFERENCES_ICONS = {
        "NeurIPS2024 D&B": NEURIPS_ICO,
        "NeurIPS2024": NEURIPS_ICO,
        "EMNLP2024": 'https://aclanthology.org/aclicon.ico',
        "CoRL2024": CORL_ICO,
        "ACMMM2024": "https://2024.acmmm.org/favicon.ico",
        "MICCAI2024": MICCAI24ICO,
        "COLM2024": COLM_ICO,
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
    }

    CONFERENCES = list(CONFERENCES_ICONS.keys())

    # Class-level constants defining columns and their data types
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    # Display order of every column that can be shown.
    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'chat_with_paper',
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'github_stars',
        'project_page',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    # Gradio Dataframe datatype for each column.
    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'number',
        'num_comments': 'number',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
        'github_stars': 'number',
        'project_page': 'markdown',
        'chat_with_paper': 'markdown',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'github': 'GitHub',
        'github_stars': 'GitHub⭐',
        'num_comments': '💬',
        'upvotes': '👍',
        'chat_with_paper': 'Chat',
    }

    # Conferences whose proceedings URLs get a "Chat with paper" link built from
    # the value of their ``id=`` query parameter.
    _NEURIPS_CHAT_CONFERENCES = ("NeurIPS2024 D&B", "NeurIPS2024")
    _PROCEEDINGS_ID_RE = re.compile(r'id=([^&]+)')

    # Length, in days, of each relative date-range option accepted by ``filter``.
    _DATE_RANGE_DAYS: Dict[str, int] = {
        "This week": 7,
        "This month": 30,
        "This year": 365,
    }

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and processing the datasets.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Columns not listed in COLUMNS_ORDER_PAPER_PAGE are dropped.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names.
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Args:
            columns (List[str]): List of (original, not display) column names.

        Returns:
            List[str]: List of data types corresponding to the columns.

        Raises:
            KeyError: If a column has no entry in DATATYPES.
        """
        return [PaperCentral.DATATYPES[c] for c in columns]

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load the Paper Central dataset and normalise its ``date`` column.

        Dates falling on a Saturday or Sunday are moved to the following Monday,
        and all dates are stored as ``'YYYY-MM-DD'`` strings.

        Returns:
            pd.DataFrame: The processed DataFrame.
        """
        columns = [
            'arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github', 'num_models',
            'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
            'conference_name', 'title', 'paper_page', 'authors', 'github_stars', 'project_page',
        ]
        # .copy() so the assignment below modifies an independent frame rather than a
        # column subset of the loaded one (avoids SettingWithCopyWarning).
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[columns].copy()

        # If arxiv published_date is weekend, switch to Monday
        dates = pd.to_datetime(paper_central_df['date'], format='%Y-%m-%d')
        weekday = dates.dt.weekday  # Monday=0 ... Sunday=6
        # Saturday (5) -> +2 days, Sunday (6) -> +1 day, other days unchanged.
        shift_days = (7 - weekday).where(weekday >= 5, 0)
        dates = dates + pd.to_timedelta(shift_days, unit='D')
        paper_central_df['date'] = dates.dt.strftime('%Y-%m-%d')
        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        Note: ``df`` is modified in place and also returned.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
        """
        df.loc[:, date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def _has_value(value: object) -> bool:
        """Return True when *value* is neither missing (NaN/None) nor falsy (e.g. "" or 0)."""
        return bool(pd.notna(value)) and bool(value)

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Prettify the DataFrame by turning identifiers and URLs into markdown/HTML links.

        The input DataFrame is not modified; a new DataFrame is returned.

        Args:
            df (pd.DataFrame): The DataFrame to prettify.

        Returns:
            pd.DataFrame: The prettified DataFrame.
        """
        hub_count_columns = (
            ('num_models', 'models'),
            ('num_datasets', 'datasets'),
            ('num_spaces', 'spaces'),
        )

        def update_row(row: pd.Series) -> pd.Series:
            """
            Rewrite the link-bearing fields of a single row.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            arxiv_id = row['arxiv_id'] if 'arxiv_id' in row else None
            # NaN is truthy, so an explicit notna check is needed before building links.
            has_arxiv_id = PaperCentral._has_value(arxiv_id)

            # Hub artifact counts become links to the Hub listing for the paper. They need
            # the raw arxiv_id, so they must run before 'arxiv_id' itself is rewritten below.
            for column, hub_section in hub_count_columns:
                if (
                    column in row and pd.notna(row[column]) and has_arxiv_id
                    and float(row[column]) > 0
                ):
                    count = int(float(row[column]))
                    row[column] = f"[{count}](https://huggingface.co/{hub_section}?other=arxiv:{arxiv_id})"

            if 'proceedings' in row and PaperCentral._has_value(row['proceedings']):
                image_url = PaperCentral.CONFERENCES_ICONS.get(row.get("conference_name"), DEFAULT_ICO)
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['proceedings']}'>proc_page</a>"
                )

            # 'paper_page' and 'arxiv_id' are rewritten last: the links above read their raw values.
            if 'paper_page' in row and PaperCentral._has_value(row['paper_page']):
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            # Add the arXiv icon and a link to the abstract page.
            if has_arxiv_id:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='https://arxiv.org/abs/{arxiv_id}'>arxiv_page</a>"
                )

            # Add the GitHub icon and a link to the repository.
            if 'github' in row and PaperCentral._has_value(row['github']):
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['github']}'>github</a>"
                )

            if 'project_page' in row and PaperCentral._has_value(row['project_page']):
                row['project_page'] = (
                    f"<a href='{row['project_page']}'>{row['project_page']}</a>"
                )

            return row

        # Copy first: update_row mutates the row objects handed to it by apply, and with a
        # single-dtype frame those may share memory with the input.
        return df.copy().apply(update_row, axis=1)

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: The DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    @staticmethod
    def _bool_mask(values: pd.Series, predicate) -> pd.Series:
        """
        Apply *predicate* element-wise and return the result as a boolean mask.

        The explicit bool dtype matters for empty input: an empty object-dtype Series
        used as ``df[series]`` is treated as a column list, which would drop every column.
        """
        return values.map(predicate).astype(bool)

    @staticmethod
    def _category_mask(categories: pd.Series, cat_options: List[str]) -> pd.Series:
        """
        Return a boolean mask of rows whose primary category matches any of *cat_options*.

        An option ending in ``".*"`` (presumably of the form ``"cs.*"``) selects a whole
        archive: the category must equal the prefix or start with ``"<prefix>."``, so
        ``"cs.*"`` no longer matches ``"physics.*"``. Any other option is matched as a
        literal, case-insensitive substring. Missing categories never match.
        """
        lowered = categories.where(categories.notna(), "").astype(str).str.lower()
        mask = pd.Series(False, index=categories.index)
        for option in cat_options:
            if option.endswith(".*"):
                prefix = option[:-2].lower()
                mask |= (lowered == prefix) | lowered.str.startswith(prefix + ".")
            else:
                mask |= lowered.str.contains(option.replace(".*", "").lower(), regex=False)
        return mask & categories.notna()

    @staticmethod
    def _date_range_bounds(date_range_option: str) -> tuple:
        """
        Return ``(start_date, end_date)`` as 'YYYY-MM-DD' strings for a relative range option.

        Both values are ``None`` for "All time" or any unrecognised option, meaning that no
        date filtering should happen.
        """
        days = PaperCentral._DATE_RANGE_DAYS.get(date_range_option)
        if days is None:
            return None, None
        today = datetime.now()
        return (today - timedelta(days=days)).strftime('%Y-%m-%d'), today.strftime('%Y-%m-%d')

    @staticmethod
    def _paper_page_chat_link(paper_page) -> str:
        """Return the "Chat with paper" HTML for a Hugging Face paper page id, or "" if missing."""
        if pd.notna(paper_page) and paper_page != "":
            return (
                f'<a action_id="chat-with-paper" paper_id="{paper_page}" paper_from="paper_page"'
                f' id="custom_button">✨ Chat with paper</a>'
            )
        return ""

    @staticmethod
    def _neurips_chat_link(proceedings) -> str:
        """
        Return the "Chat with paper" HTML for a proceedings URL carrying an ``id=`` parameter.

        Returns "" when the URL is missing, not a string, or has no ``id=`` parameter.
        """
        if not isinstance(proceedings, str):
            return ""
        match = PaperCentral._PROCEEDINGS_ID_RE.search(proceedings)
        if match is None:
            return ""
        return (
            f'<a action_id="chat-with-paper" paper_id="{match.group(1)}" paper_from="neurips"'
            f' id="custom_button">✨ Chat with paper</a>'
        )

    def filter(
            self,
            selected_date: Optional[str] = None,
            cat_options: Optional[List[str]] = None,
            hf_options: Optional[List[str]] = None,
            conference_options: Optional[List[str]] = None,
            author_search_input: Optional[str] = None,
            title_search_input: Optional[str] = None,
            date_range_option: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the DataFrame based on the UI selections, and prepare it for display.

        Args:
            selected_date: A single day to show (anything ``pd.to_datetime`` parses). Only used
                when neither ``conference_options`` nor ``date_range_option`` is set.
            cat_options: arXiv category selections; ``"(ALL)"`` disables category filtering.
            hf_options: Any of "🤗 artifacts", "datasets", "models", "spaces", "github",
                "project page". Each keeps only matching papers and adds its columns.
            conference_options: Conference names to keep (case-insensitive); ``"ALL"`` keeps
                every paper with a conference. Selecting conferences disables date filtering
                and switches to the conference column layout.
            author_search_input: Case-insensitive substring matched against each author.
            title_search_input: Case-insensitive substring matched against the title.
            date_range_option: "This week", "This month", "This year" or "All time".

        Returns:
            A ``gr.update`` carrying the display DataFrame and its column datatypes.
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        def show(*columns: str) -> None:
            """Append each column to ``columns_to_show`` unless it is already listed."""
            for column in columns:
                if column not in columns_to_show:
                    columns_to_show.append(column)

        # Handle title search
        if title_search_input:
            show('title')
            title_needle = title_search_input.lower()
            filtered_df = filtered_df[self._bool_mask(
                filtered_df['title'],
                lambda title: isinstance(title, str) and title_needle in title.lower(),
            )]

        # Handle author search
        if author_search_input:
            show('authors')
            author_needle = author_search_input.lower()

            def author_matches(authors) -> bool:
                # Type checks come first: len() on a missing (NaN) value would raise.
                if isinstance(authors, str):
                    return author_needle in authors.lower()
                if isinstance(authors, (list, tuple, pd.Series, np.ndarray)):
                    return any(
                        isinstance(author, str) and author_needle in author.lower()
                        for author in authors
                    )
                return False

            filtered_df = filtered_df[self._bool_mask(filtered_df['authors'], author_matches)]

        # Handle category options ("(ALL)" means no category filtering)
        if cat_options and "(ALL)" not in cat_options:
            filtered_df = filtered_df[self._category_mask(filtered_df['primary_category'], cat_options)]

        # Handle date filtering; conference views are never date-filtered.
        if not conference_options:
            if date_range_option:
                start_date, end_date = self._date_range_bounds(date_range_option)
                if start_date and end_date:
                    # 'date' holds 'YYYY-MM-DD' strings, which compare chronologically.
                    filtered_df = filtered_df[
                        (filtered_df['date'] >= start_date) & (filtered_df['date'] <= end_date)
                    ]
            elif selected_date:
                selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
                filtered_df = filtered_df[filtered_df['date'] == selected_date]

        # Handle Hugging Face options
        if hf_options:
            # Coerce the Hub artifact counts to integers (non-numeric/missing -> 0) so they
            # can be compared below; assign() returns a new frame instead of writing to a slice.
            filtered_df = filtered_df.assign(**{
                column: pd.to_numeric(filtered_df[column], errors='coerce').fillna(0).astype(int)
                for column in ('num_datasets', 'num_models', 'num_spaces')
            })

            if "🤗 artifacts" in hf_options:
                filtered_df = filtered_df[
                    (filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
                ]
                show('upvotes', 'num_comments', 'num_models', 'num_datasets', 'num_spaces')
                filtered_df = filtered_df[
                    (filtered_df['num_datasets'] > 0) |
                    (filtered_df['num_models'] > 0) |
                    (filtered_df['num_spaces'] > 0)
                ]

            if "datasets" in hf_options:
                show('num_datasets')
                filtered_df = filtered_df[filtered_df['num_datasets'] != 0]

            if "models" in hf_options:
                show('num_models')
                filtered_df = filtered_df[filtered_df['num_models'] != 0]

            if "spaces" in hf_options:
                show('num_spaces')
                filtered_df = filtered_df[filtered_df['num_spaces'] != 0]

            if "github" in hf_options:
                show('github', 'github_stars')
                filtered_df = filtered_df[(filtered_df['github'] != "") & (filtered_df['github'].notnull())]

            if "project page" in hf_options:
                show('project_page')
                filtered_df = filtered_df[(filtered_df['project_page'] != "") & (filtered_df['project_page'].notnull())]

        # Create the "Chat with paper" link from the Hugging Face paper page.
        # Series.map (unlike DataFrame.apply) also yields a proper column for an empty frame.
        filtered_df = filtered_df.assign(
            chat_with_paper=filtered_df['paper_page'].map(self._paper_page_chat_link)
        )
        show('chat_with_paper')

        # Apply conference filtering
        if conference_options:
            columns_to_show[:] = [col for col in columns_to_show if col not in ("date", "arxiv_id")]
            show('conference_name', 'proceedings', 'type', 'id')

            if "ALL" in conference_options:
                filtered_df = filtered_df[
                    filtered_df['conference_name'].notna() & (filtered_df['conference_name'] != "")
                ]

            other_conferences = [conf.lower() for conf in conference_options if conf != "ALL"]
            if other_conferences:
                conference_names = filtered_df['conference_name']
                filtered_df = filtered_df[
                    conference_names.notna() & conference_names.str.lower().isin(other_conferences)
                ]

            # Conference chat with paper: rows whose proceedings URL carries an id get the
            # conference chat link; other rows keep their paper-page chat link.
            if any(conf in self._NEURIPS_CHAT_CONFERENCES for conf in conference_options):
                neurips_links = filtered_df['proceedings'].map(self._neurips_chat_link)
                filtered_df = filtered_df.assign(
                    chat_with_paper=neurips_links.where(neurips_links != "", filtered_df['chat_with_paper'])
                )
                show('chat_with_paper')

        # Prettify the DataFrame
        filtered_df = self.prettify(filtered_df)

        # Ensure columns are ordered according to COLUMNS_ORDER_PAPER_PAGE
        filtered_df = filtered_df[self.get_columns_order(columns_to_show)]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Show entries with a 'paper_page' first. The stable sort keeps the existing
        # relative order inside each group (the default quicksort does not).
        if 'paper_page' in filtered_df.columns:
            has_paper_page = (
                filtered_df['paper_page'].notna() & (filtered_df['paper_page'] != "")
            ).to_numpy(dtype=bool)
            filtered_df = filtered_df.iloc[np.argsort(~has_paper_page, kind='stable')]

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column (the input itself if it was never renamed).
        """
        inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
        return inverse_map.get(display_column_name, display_column_name)