import pandas as pd
from typing import List, Dict, Optional
from constants import (
DATASET_ARXIV_SCAN_PAPERS,
DATASET_CONFERENCE_PAPERS,
DATASET_COMMUNITY_SCIENCE,
NEURIPS_ICO,
DATASET_PAPER_CENTRAL,
COLM_ICO,
DEFAULT_ICO,
MICCAI24ICO,
)
import gradio as gr
from utils import load_and_process
import numpy as np
class PaperCentral:
"""
A class to manage and process paper data for display in a Gradio Dataframe component.
"""
CONFERENCES = [
"ACL2023",
"ACL2024",
"COLING2024",
"CVPR2023",
"CVPR2024",
"ECCV2024",
"EMNLP2023",
"NAACL2023",
"NeurIPS2023",
"NeurIPS2023 D&B",
"COLM2024",
"MICCAI2024",
]
CONFERENCES_ICONS = {
"ACL2023": 'https://aclanthology.org/aclicon.ico',
"ACL2024": 'https://aclanthology.org/aclicon.ico',
"COLING2024": 'https://aclanthology.org/aclicon.ico',
"CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
"CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
"ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
"EMNLP2023": 'https://aclanthology.org/aclicon.ico',
"NAACL2023": 'https://aclanthology.org/aclicon.ico',
"NeurIPS2023": NEURIPS_ICO,
"NeurIPS2023 D&B": NEURIPS_ICO,
"COLM2024": COLM_ICO,
"MICCAI2024": MICCAI24ICO,
}
# Class-level constants defining columns and their data types
COLUMNS_START_PAPER_PAGE: List[str] = [
'date',
'arxiv_id',
'paper_page',
'title',
]
COLUMNS_ORDER_PAPER_PAGE: List[str] = [
'date',
'arxiv_id',
'paper_page',
'num_models',
'num_datasets',
'num_spaces',
'upvotes',
'num_comments',
'github',
'conference_name',
'id',
'type',
'proceedings',
'title',
'authors',
]
DATATYPES: Dict[str, str] = {
'date': 'str',
'arxiv_id': 'markdown',
'paper_page': 'markdown',
'upvotes': 'str',
'num_comments': 'str',
'num_models': 'markdown',
'num_datasets': 'markdown',
'num_spaces': 'markdown',
'github': 'markdown',
'title': 'str',
'proceedings': 'markdown',
'conference_name': 'str',
'id': 'str',
'type': 'str',
'authors': 'str',
}
# Mapping for renaming columns for display purposes
COLUMN_RENAME_MAP: Dict[str, str] = {
'num_models': 'models',
'num_spaces': 'spaces',
'num_datasets': 'datasets',
'conference_name': 'venue',
}
def __init__(self):
"""
Initialize the PaperCentral class by loading and processing the datasets.
"""
self.df_raw: pd.DataFrame = self.get_df()
self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)
@staticmethod
def get_columns_order(columns: List[str]) -> List[str]:
"""
Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.
Args:
columns (List[str]): List of column names to order.
Returns:
List[str]: Ordered list of column names.
"""
return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]
@staticmethod
def get_columns_datatypes(columns: List[str]) -> List[str]:
"""
Get data types for the specified columns.
Args:
columns (List[str]): List of column names.
Returns:
List[str]: List of data types corresponding to the columns.
"""
return [PaperCentral.DATATYPES[c] for c in columns]
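    # Illustrative example (values taken from the class constants above):
    #   get_columns_order(['title', 'date', 'github'])            -> ['date', 'github', 'title']
    #   get_columns_datatypes(['date', 'github', 'title'])        -> ['str', 'markdown', 'str']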
@staticmethod
def get_df() -> pd.DataFrame:
"""
Load and merge datasets to create the raw DataFrame.
Returns:
pd.DataFrame: The merged and processed DataFrame.
"""
# Load datasets
paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[
['arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github', 'num_models',
'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
'conference_name', 'title', 'paper_page', 'authors']
]
return paper_central_df
@staticmethod
def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
"""
Format the date column in the DataFrame to 'YYYY-MM-DD'.
Args:
df (pd.DataFrame): The DataFrame to format.
date_column (str): The name of the date column.
Returns:
pd.DataFrame: The DataFrame with the formatted date column.
"""
df.loc[:, date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
return df
@staticmethod
def prettify(df: pd.DataFrame) -> pd.DataFrame:
"""
Prettify the DataFrame by adding markdown links and sorting.
Args:
df (pd.DataFrame): The DataFrame to prettify.
Returns:
pd.DataFrame: The prettified DataFrame.
"""
def update_row(row: pd.Series) -> pd.Series:
"""
Update a row by adding markdown links to 'paper_page' and 'arxiv_id' columns.
Args:
row (pd.Series): A row from the DataFrame.
Returns:
pd.Series: The updated row.
"""
# Process 'num_models' column
if (
'num_models' in row and pd.notna(row['num_models']) and row["arxiv_id"]
and float(row['num_models']) > 0
):
num_models = int(float(row['num_models']))
row['num_models'] = (
f"[{num_models}](https://huggingface.co/models?other=arxiv:{row['arxiv_id']})"
)
if (
'num_datasets' in row and pd.notna(row['num_datasets']) and row["arxiv_id"]
and float(row['num_datasets']) > 0
):
num_datasets = int(float(row['num_datasets']))
row['num_datasets'] = (
f"[{num_datasets}](https://huggingface.co/datasets?other=arxiv:{row['arxiv_id']})"
)
if (
'num_spaces' in row and pd.notna(row['num_spaces']) and row["arxiv_id"]
and float(row['num_spaces']) > 0
):
num_spaces = int(float(row['num_spaces']))
row['num_spaces'] = (
f"[{num_spaces}](https://huggingface.co/spaces?other=arxiv:{row['arxiv_id']})"
)
            if 'proceedings' in row and pd.notna(row['proceedings']) and row['proceedings']:
                image_url = PaperCentral.CONFERENCES_ICONS.get(row["conference_name"], DEFAULT_ICO)
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = (
                    f"<img src='{image_url}' style='{style}'/> "
                    f"<a href='{row['proceedings']}' target='_blank'>proc_page</a>"
                )
            # Note: 'paper_page' and 'arxiv_id' must be processed last, because the
            # links built above rely on the raw 'arxiv_id' value.
# Add markdown link to 'paper_page' if it exists
if 'paper_page' in row and pd.notna(row['paper_page']) and row['paper_page']:
row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"
            # Add image and link to 'arxiv_id' if it exists
            if 'arxiv_id' in row and pd.notna(row['arxiv_id']) and row['arxiv_id']:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f"<img src='{image_url}' style='{style}'/> "
                    f"<a href='https://arxiv.org/abs/{row['arxiv_id']}' target='_blank'>arxiv_page</a>"
                )
            # Add image and link to 'github' if it exists
            if 'github' in row and pd.notna(row['github']) and row["github"]:
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f"<img src='{image_url}' style='{style}'/> "
                    f"<a href='{row['github']}' target='_blank'>github</a>"
                )
return row
df = df.copy()
# Apply the update_row function to each row
prettified_df: pd.DataFrame = df.apply(update_row, axis=1)
return prettified_df
def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.
Args:
df (pd.DataFrame): The DataFrame whose columns need to be renamed.
Returns:
pd.DataFrame: The DataFrame with renamed columns.
"""
return df.rename(columns=self.COLUMN_RENAME_MAP)
def filter(
self,
selected_date: Optional[str] = None,
cat_options: Optional[List[str]] = None,
hf_options: Optional[List[str]] = None,
conference_options: Optional[List[str]] = None,
author_search_input: Optional[str] = None,
title_search_input: Optional[str] = None,
) -> gr.update:
"""
Filter the DataFrame based on selected date and options, and prepare it for display.
Args:
selected_date (Optional[str]): The date to filter the DataFrame.
hf_options (Optional[List[str]]): List of options selected by the user.
conference_options (Optional[List[str]]): List of conference options selected by the user.
Returns:
gr.Update: An update object for the Gradio Dataframe component.
"""
filtered_df: pd.DataFrame = self.df_raw.copy()
# Start with the initial columns to display
columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()
        if title_search_input:
            if 'title' not in columns_to_show:
                columns_to_show.append('title')
            search_string = title_search_input.lower()

            def title_match(title):
                # Match only when the title is a string containing the search string
                if isinstance(title, str):
                    return search_string in title.lower()
                # Handle missing or unexpected data types
                return False

            filtered_df = filtered_df[filtered_df['title'].apply(title_match)]
if author_search_input:
if 'authors' not in columns_to_show:
columns_to_show.append('authors')
search_string = author_search_input.lower()
def author_matches(authors_list):
# Check if authors_list is None or empty
if authors_list is None or len(authors_list) == 0:
return False
# Check if authors_list is an iterable (list, tuple, Series, or ndarray)
if isinstance(authors_list, (list, tuple, pd.Series, np.ndarray)):
return any(
isinstance(author, str) and search_string in author.lower()
for author in authors_list
)
elif isinstance(authors_list, str):
# If authors_list is a single string
return search_string in authors_list.lower()
else:
# Handle unexpected data types
return False
filtered_df = filtered_df[filtered_df['authors'].apply(author_matches)]
        if cat_options:
            options = [o.replace(".*", "") for o in cat_options]
            # Initialize filter series
            category_filter = pd.Series(False, index=filtered_df.index)
            for option in options:
                # Keep rows whose 'primary_category' contains the category string (case-insensitive)
                category_filter |= (
                        filtered_df['primary_category'].notna() &
                        filtered_df['primary_category'].str.contains(option, case=False)
                )
            filtered_df = filtered_df[category_filter]
        # Date filter (skipped when a conference filter is active)
if selected_date and not conference_options:
selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
filtered_df = filtered_df[filtered_df['date'] == selected_date]
# HF options
if hf_options:
if "🤗 artifacts" in hf_options:
                # Keep rows where 'paper_page' is neither empty nor NaN
filtered_df = filtered_df[
(filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
]
# Add 'upvotes' column if not already in columns_to_show
if 'upvotes' not in columns_to_show:
columns_to_show.append('upvotes')
# Add 'num_models' column if not already in columns_to_show
if 'num_models' not in columns_to_show:
columns_to_show.append('num_models')
if 'num_datasets' not in columns_to_show:
columns_to_show.append('num_datasets')
if 'num_spaces' not in columns_to_show:
columns_to_show.append('num_spaces')
if "datasets" in hf_options:
if 'num_datasets' not in columns_to_show:
columns_to_show.append('num_datasets')
filtered_df = filtered_df[filtered_df['num_datasets'] != 0]
if "models" in hf_options:
if 'num_models' not in columns_to_show:
columns_to_show.append('num_models')
filtered_df = filtered_df[filtered_df['num_models'] != 0]
if "spaces" in hf_options:
if 'num_spaces' not in columns_to_show:
columns_to_show.append('num_spaces')
filtered_df = filtered_df[filtered_df['num_spaces'] != 0]
if "github" in hf_options:
if 'github' not in columns_to_show:
columns_to_show.append('github')
filtered_df = filtered_df[(filtered_df['github'] != "") & (filtered_df['github'].notnull())]
# Apply conference filtering
if conference_options:
columns_to_show.remove("date")
columns_to_show.remove("arxiv_id")
if 'conference_name' not in columns_to_show:
columns_to_show.append('conference_name')
if 'proceedings' not in columns_to_show:
columns_to_show.append('proceedings')
if 'type' not in columns_to_show:
columns_to_show.append('type')
if 'id' not in columns_to_show:
columns_to_show.append('id')
# If "In proceedings" is selected
if "In proceedings" in conference_options:
# Filter rows where 'conference_name' is not None, not NaN, and not empty
filtered_df = filtered_df[
filtered_df['conference_name'].notna() & (filtered_df['conference_name'] != "")
]
# For other conference options
other_conferences = [conf for conf in conference_options if conf != "In proceedings"]
if other_conferences:
# Initialize filter series
conference_filter = pd.Series(False, index=filtered_df.index)
for conference in other_conferences:
# Filter rows where 'conference_name' contains the conference string (case-insensitive)
conference_filter |= (
filtered_df['conference_name'].notna() &
(filtered_df['conference_name'].str.lower() == conference.lower())
)
filtered_df = filtered_df[conference_filter]
# Prettify the DataFrame
filtered_df = self.prettify(filtered_df)
        # Ensure columns are ordered according to COLUMNS_ORDER_PAPER_PAGE
        columns_in_order: List[str] = PaperCentral.get_columns_order(columns_to_show)
# Select and reorder the columns
filtered_df = filtered_df[columns_in_order]
# Rename columns for display
filtered_df = self.rename_columns_for_display(filtered_df)
# Get the corresponding data types for the columns
new_datatypes: List[str] = [
PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
]
# Sort rows to display entries with 'paper_page' first
if 'paper_page' in filtered_df.columns:
filtered_df['has_paper_page'] = filtered_df['paper_page'].notna() & (filtered_df['paper_page'] != "")
filtered_df.sort_values(by='has_paper_page', ascending=False, inplace=True)
filtered_df.drop(columns='has_paper_page', inplace=True)
# Return an update object to modify the Dataframe component
return gr.update(value=filtered_df, datatype=new_datatypes)
def _get_original_column_name(self, display_column_name: str) -> str:
"""
Retrieve the original column name given a display column name.
Args:
display_column_name (str): The display name of the column.
Returns:
str: The original name of the column.
"""
inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
return inverse_map.get(display_column_name, display_column_name)
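

if __name__ == "__main__":
    # Minimal usage sketch, not part of the class: it assumes the `constants` and
    # `utils` modules are importable and that DATASET_PAPER_CENTRAL loads successfully.
    # The date "2024-10-01" is only an illustrative default. It wires `filter` to a
    # bare-bones Gradio app so the update object it returns refreshes the table.
    paper_central = PaperCentral()

    with gr.Blocks() as demo:
        date_picker = gr.Textbox(label="Date (YYYY-MM-DD)", value="2024-10-01")
        dataframe = gr.Dataframe(datatype="markdown")

        # Populate the table on load and whenever the date changes.
        demo.load(
            fn=lambda d: paper_central.filter(selected_date=d),
            inputs=date_picker,
            outputs=dataframe,
        )
        date_picker.change(
            fn=lambda d: paper_central.filter(selected_date=d),
            inputs=date_picker,
            outputs=dataframe,
        )

    demo.launch()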