# Spaces: Running
from typing import Dict, List, Optional

import gradio as gr
import numpy as np
import pandas as pd

from constants import (
    DATASET_ARXIV_SCAN_PAPERS,
    DATASET_COMMUNITY_SCIENCE,
    DATASET_CONFERENCE_PAPERS,
    DATASET_PAPER_CENTRAL,
    NEURIPS_ICO,
)
from utils import load_and_process
class PaperCentral:
    """
    Manage and process paper data for display in a Gradio Dataframe component.

    The raw table is loaded once in ``__init__``; ``filter`` derives the
    displayed view from a copy of that table on every call.
    """

    CONFERENCES = [
        "ACL2023",
        "ACL2024",
        "COLING2024",
        "CVPR2023",
        "CVPR2024",
        "ECCV2024",
        "EMNLP2023",
        "NAACL2023",
        "NeurIPS2023",
        "NeurIPS2023 D&B",
    ]

    # Favicon shown next to the proceedings link, keyed by conference name.
    CONFERENCES_ICONS = {
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
    }

    # Columns every view starts with; `filter` adds/removes columns from a copy.
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    # Canonical left-to-right display order of all displayable columns.
    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    # Gradio Dataframe datatype per (original, un-renamed) column name.
    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'str',
        'num_comments': 'str',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'conference_name': 'venue',
    }

    # Count column -> path segment of the matching huggingface.co listing page.
    _HUB_SECTIONS: Dict[str, str] = {
        'num_models': 'models',
        'num_datasets': 'datasets',
        'num_spaces': 'spaces',
    }

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and prettifying the dataset.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Columns that are not listed in COLUMNS_ORDER_PAPER_PAGE are dropped.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names.
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Columns without an entry in DATATYPES fall back to 'str', which is
        the same default `filter` uses.

        Args:
            columns (List[str]): List of column names.

        Returns:
            List[str]: List of data types corresponding to the columns.
        """
        return [PaperCentral.DATATYPES.get(c, 'str') for c in columns]

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load the paper-central dataset and keep the columns used by this class.

        Loading is delegated to `load_and_process`; its exact behavior is not
        visible here (presumably it returns a DataFrame — verify in `utils`).

        Returns:
            pd.DataFrame: The loaded table restricted to the selected columns.
        """
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[
            ['arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github', 'num_models',
             'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
             'conference_name', 'title', 'paper_page', 'authors']
        ]
        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        The DataFrame is modified in place and also returned.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
        """
        # Plain column assignment replaces the column, so its dtype can become
        # string. `df.loc[:, col] = ...` sets values in place and, on recent
        # pandas, may keep a datetime dtype so the formatting would not stick.
        df[date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Prettify the DataFrame by adding markdown/HTML links and sorting.

        Rows that have a 'paper_page' value are moved to the top; the relative
        order of rows is otherwise preserved. The input frame is not modified.

        Args:
            df (pd.DataFrame): The DataFrame to prettify.

        Returns:
            pd.DataFrame: The prettified DataFrame.
        """

        def has_value(row: pd.Series, column: str) -> bool:
            """Return True when `column` exists in `row` with a non-null, truthy value."""
            if column not in row:
                return False
            value = row[column]
            return bool(pd.notna(value)) and bool(value)

        def positive_count(row: pd.Series, column: str) -> Optional[int]:
            """Return the integer count in `column` when it is > 0, else None."""
            if column not in row or pd.isna(row[column]):
                return None
            try:
                value = float(row[column])
            except (TypeError, ValueError):
                # Not a number (e.g. an already-prettified cell): leave it alone.
                return None
            return int(value) if value > 0 else None

        def update_row(row: pd.Series) -> pd.Series:
            """
            Replace raw cell values of one row with markdown/HTML links.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            # Captured up front: the count links need the raw id, and the
            # 'arxiv_id' cell itself is overwritten with HTML further down.
            # A NaN id is treated as missing so no 'arxiv:nan' link is built.
            arxiv_id = row['arxiv_id'] if has_value(row, 'arxiv_id') else None

            # Turn positive model/dataset/space counts into Hub search links.
            if arxiv_id is not None:
                for column, section in PaperCentral._HUB_SECTIONS.items():
                    count = positive_count(row, column)
                    if count is not None:
                        row[column] = (
                            f"[{count}](https://huggingface.co/{section}?other=arxiv:{arxiv_id})"
                        )

            # Proceedings link, prefixed by the venue icon when one is known.
            if has_value(row, 'proceedings'):
                icon_url = PaperCentral.CONFERENCES_ICONS.get(row.get('conference_name'))
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                icon = f"<img src='{icon_url}' style='{style}'/>" if icon_url else ""
                row['proceedings'] = (
                    f"{icon}"
                    f"<a href='{row['proceedings']}'>proc_page</a>"
                )

            # Add markdown link to 'paper_page' if it exists
            if has_value(row, 'paper_page'):
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            # Add image and link to 'arxiv_id' if it exists
            if arxiv_id is not None:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='https://arxiv.org/abs/{arxiv_id}'>arxiv_page</a>"
                )

            # Add image and link to 'github' if it exists
            if has_value(row, 'github'):
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['github']}'>github</a>"
                )

            return row

        working = df.copy()

        # Sort rows to display entries with 'paper_page' first. A stable sort
        # on the "is missing" flag keeps the existing order within each group.
        if 'paper_page' in working.columns:
            order = np.argsort(working['paper_page'].isna().to_numpy(), kind='stable')
            working = working.iloc[order]

        # Apply the update_row function to each row
        prettified_df: pd.DataFrame = working.apply(update_row, axis=1)
        return prettified_df

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: The DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    @staticmethod
    def _author_matches(authors, needle: str) -> bool:
        """Return True when the lower-cased `needle` occurs in any author name."""
        if isinstance(authors, str):
            return needle in authors.lower()
        if isinstance(authors, (list, tuple, pd.Series, np.ndarray)):
            return any(isinstance(author, str) and needle in author.lower() for author in authors)
        # None, NaN and any other unexpected cell type never match.
        return False

    @staticmethod
    def _row_mask(series: pd.Series, predicate) -> pd.Series:
        """Apply `predicate` per cell and return a mask that is always bool-typed."""
        # `apply` on an empty object column yields an empty object Series; the
        # explicit cast keeps it usable as a row mask in that case too.
        return series.apply(predicate).astype(bool)

    def filter(
            self,
            selected_date: Optional[str] = None,
            cat_options: Optional[List[str]] = None,
            hf_options: Optional[List[str]] = None,
            conference_options: Optional[List[str]] = None,
            author_search_input: Optional[str] = None,
            title_search_input: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the DataFrame based on the selected options and prepare it for display.

        Args:
            selected_date (Optional[str]): Date to keep; ignored when any
                conference option is selected.
            cat_options (Optional[List[str]]): Category options; a ".*" in an
                option is stripped and the remainder is matched as a
                case-insensitive literal substring of 'primary_category'.
            hf_options (Optional[List[str]]): Any of "🤗 paper-page",
                "datasets", "models", "spaces", "github".
            conference_options (Optional[List[str]]): Conference names and/or
                the special value "In proceedings".
            author_search_input (Optional[str]): Case-insensitive substring to
                look for in the authors.
            title_search_input (Optional[str]): Case-insensitive substring to
                look for in the title.

        Returns:
            gr.Update: An update object for the Gradio Dataframe component.
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        def show(*columns: str) -> None:
            """Add each column to the displayed set unless already present."""
            for column in columns:
                if column not in columns_to_show:
                    columns_to_show.append(column)

        # Title search
        if title_search_input:
            show('title')
            title_needle = title_search_input.lower()
            title_mask = self._row_mask(
                filtered_df['title'],
                lambda title: isinstance(title, str) and title_needle in title.lower(),
            )
            filtered_df = filtered_df.loc[title_mask]

        # Author search
        if author_search_input:
            show('authors')
            author_needle = author_search_input.lower()
            author_mask = self._row_mask(
                filtered_df['authors'],
                lambda authors: self._author_matches(authors, author_needle),
            )
            filtered_df = filtered_df.loc[author_mask]

        # Category filtering: keep rows matching at least one selected category.
        if cat_options:
            primary_category = filtered_df['primary_category']
            category_mask = pd.Series(False, index=filtered_df.index)
            for option in cat_options:
                prefix = option.replace(".*", "")
                # Literal match: the '.' in e.g. "cs.CL" must not act as a wildcard.
                category_mask |= primary_category.str.contains(
                    prefix, case=False, regex=False, na=False
                )
            filtered_df = filtered_df.loc[category_mask]

        # Date (not applied when browsing by conference)
        if selected_date and not conference_options:
            selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
            filtered_df = filtered_df.loc[filtered_df['date'] == selected_date]

        # HF options
        if hf_options:
            if "🤗 paper-page" in hf_options:
                # Filter rows where 'paper_page' is not empty or NaN
                filtered_df = filtered_df.loc[
                    (filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
                ]
                show('upvotes', 'num_models', 'num_datasets', 'num_spaces')

            for option, column in (
                    ("datasets", 'num_datasets'),
                    ("models", 'num_models'),
                    ("spaces", 'num_spaces'),
            ):
                if option in hf_options:
                    show(column)
                    # Missing or non-numeric counts are treated as zero.
                    counts = pd.to_numeric(filtered_df[column], errors='coerce').fillna(0)
                    filtered_df = filtered_df.loc[counts > 0]

            if "github" in hf_options:
                show('github')
                filtered_df = filtered_df.loc[
                    (filtered_df['github'] != "") & (filtered_df['github'].notnull())
                ]

        # Apply conference filtering
        if conference_options:
            for hidden in ("date", "arxiv_id"):
                if hidden in columns_to_show:
                    columns_to_show.remove(hidden)
            show('conference_name', 'proceedings', 'type', 'id')

            # If "In proceedings" is selected
            if "In proceedings" in conference_options:
                # Filter rows where 'conference_name' is not None, not NaN, and not empty
                filtered_df = filtered_df.loc[
                    filtered_df['conference_name'].notna() & (filtered_df['conference_name'] != "")
                ]

            # For other conference options: exact, case-insensitive name match.
            wanted = {conf.lower() for conf in conference_options if conf != "In proceedings"}
            if wanted:
                # NaN names stay NaN after `.str.lower()` and never match `isin`.
                venue_mask = filtered_df['conference_name'].str.lower().isin(wanted)
                filtered_df = filtered_df.loc[venue_mask]

        # Prettify the DataFrame
        filtered_df = self.prettify(filtered_df)

        # Ensure columns are ordered according to COLUMNS_ORDER_PAPER_PAGE
        columns_in_order: List[str] = self.get_columns_order(columns_to_show)

        # Select and reorder the columns
        filtered_df = filtered_df[columns_in_order]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Names that are not display names in COLUMN_RENAME_MAP are returned unchanged.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column.
        """
        inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
        return inverse_map.get(display_column_name, display_column_name)