import html
from typing import List, Dict, Optional

import gradio as gr
import numpy as np
import pandas as pd

from constants import (
    DATASET_ARXIV_SCAN_PAPERS,
    DATASET_CONFERENCE_PAPERS,
    DATASET_COMMUNITY_SCIENCE,
    NEURIPS_ICO,
    DATASET_PAPER_CENTRAL,
    COLM_ICO,
    DEFAULT_ICO,
    MICCAI24ICO,
)
from utils import load_and_process


class PaperCentral:
    """
    A class to manage and process paper data for display in a Gradio Dataframe component.
    """

    # Venue names this class knows about.
    CONFERENCES = [
        "ACL2023",
        "ACL2024",
        "COLING2024",
        "CVPR2023",
        "CVPR2024",
        "ECCV2024",
        "EMNLP2023",
        "NAACL2023",
        "NeurIPS2023",
        "NeurIPS2023 D&B",
        "COLM2024",
        "MICCAI2024",
    ]

    # Favicon shown next to the proceedings link, keyed by 'conference_name'.
    # Venues missing from this map fall back to DEFAULT_ICO in prettify().
    CONFERENCES_ICONS = {
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
        "COLM2024": COLM_ICO,
        "MICCAI2024": MICCAI24ICO,
    }

    # Class-level constants defining columns and their data types
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'str',
        'num_comments': 'str',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'conference_name': 'venue',
    }

    # (count column, Hugging Face Hub listing) pairs: prettify() turns a positive
    # count into a link to https://huggingface.co/<listing>?other=arxiv:<arxiv_id>.
    _ARTIFACT_LINKS = (
        ('num_models', 'models'),
        ('num_datasets', 'datasets'),
        ('num_spaces', 'spaces'),
    )

    # (hf_options entry, count column) pairs: selecting the option keeps only
    # rows with a positive count in that column.
    _HF_COUNT_FILTERS = (
        ("datasets", "num_datasets"),
        ("models", "num_models"),
        ("spaces", "num_spaces"),
    )

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and processing the datasets.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names. Names not listed in
            COLUMNS_ORDER_PAPER_PAGE are dropped.
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Args:
            columns (List[str]): List of column names.

        Returns:
            List[str]: List of data types corresponding to the columns. Columns
            without an entry in DATATYPES default to 'str', matching the fallback
            used by filter().
        """
        return [PaperCentral.DATATYPES.get(c, 'str') for c in columns]

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load and merge datasets to create the raw DataFrame.

        Returns:
            pd.DataFrame: The DATASET_PAPER_CENTRAL data as returned by
            load_and_process, restricted to the columns used by this class.
        """
        # Load datasets
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[
            ['arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github',
             'num_models', 'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
             'conference_name', 'title', 'paper_page', 'authors']
        ]

        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        Note: the column is replaced on ``df`` itself (the input is modified),
        and the same object is returned.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
        """
        # Replace the whole column rather than writing strings into it in place via
        # .loc[:, col]: an in-place write of strings into a datetime64 column
        # changes its dtype, which pandas warns about and will reject in the future.
        df[date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def _has_value(value) -> bool:
        """
        Return True when a scalar cell is neither missing (None/NaN/NA) nor falsy ("").

        Checking notna first matters: NaN is truthy, so a bare truthiness test
        would treat a missing arxiv id as present.
        """
        return bool(pd.notna(value)) and bool(value)

    @staticmethod
    def _icon_link(image_url: str, style: str, href: str, label: str) -> str:
        """
        Build an inline icon followed by a hyperlink, as HTML for a markdown cell.

        URL values are HTML-escaped so quotes or ampersands in the data cannot
        break out of the attribute.
        """
        safe_src = html.escape(str(image_url), quote=True)
        safe_href = html.escape(str(href), quote=True)
        # NOTE(review): markup reconstructed so the computed icon URL, style and
        # target URL are actually used; confirm it renders as intended in the
        # Gradio markdown cell.
        return (
            f"<img src='{safe_src}' style='{style}'/>"
            f"<a href='{safe_href}' target='_blank'>{label}</a>"
        )

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Prettify the DataFrame by turning identifiers and counts into markdown/HTML links.

        Args:
            df (pd.DataFrame): The DataFrame to prettify. It is not modified.

        Returns:
            pd.DataFrame: A new DataFrame with the link columns rewritten.
        """

        def update_row(row: pd.Series) -> pd.Series:
            """
            Rewrite the link-bearing cells of a single row.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            arxiv_id = row.get('arxiv_id')
            has_arxiv_id = PaperCentral._has_value(arxiv_id)

            # Artifact counts -> links to the Hub listing filtered by arxiv id.
            for column, hub_listing in PaperCentral._ARTIFACT_LINKS:
                count = row.get(column)
                if has_arxiv_id and pd.notna(count) and float(count) > 0:
                    row[column] = (
                        f"[{int(float(count))}](https://huggingface.co/{hub_listing}?other=arxiv:{arxiv_id})"
                    )

            proceedings = row.get('proceedings')
            if PaperCentral._has_value(proceedings):
                image_url = PaperCentral.CONFERENCES_ICONS.get(row.get('conference_name'), DEFAULT_ICO)
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = PaperCentral._icon_link(image_url, style, proceedings, "proc_page")

            # The raw 'arxiv_id' value is needed by the count links above, so the
            # cell may only be overwritten after they have been built.
            if PaperCentral._has_value(row.get('paper_page')):
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            if has_arxiv_id:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = PaperCentral._icon_link(
                    image_url, style, f"https://arxiv.org/abs/{arxiv_id}", "arxiv_page"
                )

            github = row.get('github')
            if PaperCentral._has_value(github):
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                # NOTE(review): assumes the 'github' cell holds a full URL — TODO confirm.
                row['github'] = PaperCentral._icon_link(image_url, style, github, "github")

            return row

        df = df.copy()
        # Apply the update_row function to each row
        prettified_df: pd.DataFrame = df.apply(update_row, axis=1)
        return prettified_df

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: The DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    @staticmethod
    def _add_columns(columns: List[str], *names: str) -> None:
        """Append each of ``names`` to ``columns`` (in place) unless already present."""
        for name in names:
            if name not in columns:
                columns.append(name)

    @staticmethod
    def _title_mask(titles: pd.Series, query: str) -> pd.Series:
        """Boolean mask of titles containing ``query`` (case-insensitive); non-strings never match."""
        needle = query.lower()

        def matches(title) -> bool:
            return isinstance(title, str) and needle in title.lower()

        # astype(bool): on an empty Series, apply() yields an object-dtype result
        # that pandas would treat as a column selector instead of a row mask.
        return titles.apply(matches).astype(bool)

    @staticmethod
    def _author_mask(authors: pd.Series, query: str) -> pd.Series:
        """
        Boolean mask of rows where any author contains ``query`` (case-insensitive).

        A cell may be a single string or a list/tuple/Series/ndarray of names;
        any other value (None, NaN, ...) never matches.
        """
        needle = query.lower()

        def matches(authors_list) -> bool:
            # Type checks come first: calling len() on a NaN (float) cell would raise.
            if isinstance(authors_list, str):
                return needle in authors_list.lower()
            if isinstance(authors_list, (list, tuple, pd.Series, np.ndarray)):
                return any(
                    isinstance(author, str) and needle in author.lower()
                    for author in authors_list
                )
            return False

        return authors.apply(matches).astype(bool)

    @staticmethod
    def _category_mask(categories: pd.Series, cat_options: List[str]) -> pd.Series:
        """
        Boolean mask of rows whose primary category matches any of ``cat_options``.

        An option of the form ``"prefix.*"`` matches the category ``prefix`` itself
        and every ``prefix.<sub>`` category (case-insensitive). A plain substring
        test would make ``"cs.*"`` match ``"physics.optics"``.
        Other options keep the substring behaviour: a case-insensitive
        ``str.contains`` on the option with ``".*"`` removed.
        """
        mask = pd.Series(False, index=categories.index)
        lowered = categories.str.lower()
        for option in cat_options:
            if option.endswith(".*"):
                prefix = option[:-2].lower()
                mask |= (lowered == prefix) | lowered.str.startswith(prefix + ".", na=False)
            else:
                mask |= categories.str.contains(option.replace(".*", ""), case=False, na=False)
        return mask

    @staticmethod
    def _positive_mask(counts: pd.Series) -> pd.Series:
        """Boolean mask of rows whose count is a number greater than zero (NaN/non-numeric -> False)."""
        return pd.to_numeric(counts, errors='coerce').fillna(0) > 0

    @staticmethod
    def _apply_hf_options(
        df: pd.DataFrame, hf_options: List[str], columns_to_show: List[str]
    ) -> pd.DataFrame:
        """
        Filter ``df`` by the Hugging Face artifact options.

        Also appends the columns those options need to ``columns_to_show``
        (the list is modified in place).
        """
        if "🤗 artifacts" in hf_options:
            # Keep rows that have a paper page
            paper_pages = df['paper_page']
            df = df[(paper_pages != "") & paper_pages.notna()]
            PaperCentral._add_columns(
                columns_to_show, 'upvotes', 'num_models', 'num_datasets', 'num_spaces'
            )

        for option, column in PaperCentral._HF_COUNT_FILTERS:
            if option in hf_options:
                PaperCentral._add_columns(columns_to_show, column)
                df = df[PaperCentral._positive_mask(df[column])]

        if "github" in hf_options:
            PaperCentral._add_columns(columns_to_show, 'github')
            repos = df['github']
            df = df[(repos != "") & repos.notnull()]

        return df

    @staticmethod
    def _apply_conference_options(
        df: pd.DataFrame, conference_options: List[str], columns_to_show: List[str]
    ) -> pd.DataFrame:
        """
        Filter ``df`` by venue.

        Also swaps the date/arxiv columns in ``columns_to_show`` (modified in
        place) for the venue-related ones.
        """
        for column in ("date", "arxiv_id"):
            if column in columns_to_show:
                columns_to_show.remove(column)
        PaperCentral._add_columns(columns_to_show, 'conference_name', 'proceedings', 'type', 'id')

        # "In proceedings": keep rows that have any non-empty venue
        if "In proceedings" in conference_options:
            venues = df['conference_name']
            df = df[venues.notna() & (venues != "")]

        # Specific venues: exact, case-insensitive match on 'conference_name'
        wanted = {conf.lower() for conf in conference_options if conf != "In proceedings"}
        if wanted:
            df = df[df['conference_name'].str.lower().isin(wanted)]

        return df

    def filter(
        self,
        selected_date: Optional[str] = None,
        cat_options: Optional[List[str]] = None,
        hf_options: Optional[List[str]] = None,
        conference_options: Optional[List[str]] = None,
        author_search_input: Optional[str] = None,
        title_search_input: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the DataFrame based on selected date and options, and prepare it for display.

        Args:
            selected_date (Optional[str]): Keep only rows with this date. The value is
                normalised to 'YYYY-MM-DD'. Ignored when conference_options is set.
            cat_options (Optional[List[str]]): Primary-category options. "prefix.*"
                selects a whole category family.
            hf_options (Optional[List[str]]): Hugging Face artifact options
                ("🤗 artifacts", "datasets", "models", "spaces", "github").
            conference_options (Optional[List[str]]): Venue options. "In proceedings"
                keeps every row with a venue.
            author_search_input (Optional[str]): Case-insensitive substring searched
                in the authors.
            title_search_input (Optional[str]): Case-insensitive substring searched
                in the title.

        Returns:
            gr.Update: An update object for the Gradio Dataframe component.
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        if title_search_input:
            self._add_columns(columns_to_show, 'title')
            filtered_df = filtered_df[self._title_mask(filtered_df['title'], title_search_input)]

        if author_search_input:
            self._add_columns(columns_to_show, 'authors')
            filtered_df = filtered_df[self._author_mask(filtered_df['authors'], author_search_input)]

        if cat_options:
            filtered_df = filtered_df[self._category_mask(filtered_df['primary_category'], cat_options)]

        # Date
        if selected_date and not conference_options:
            selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
            filtered_df = filtered_df[filtered_df['date'] == selected_date]

        # HF options
        if hf_options:
            filtered_df = self._apply_hf_options(filtered_df, hf_options, columns_to_show)

        # Apply conference filtering
        if conference_options:
            filtered_df = self._apply_conference_options(filtered_df, conference_options, columns_to_show)

        # Prettify only the surviving rows
        filtered_df = self.prettify(filtered_df)

        # Select and reorder the columns according to COLUMNS_ORDER_PAPER_PAGE
        filtered_df = filtered_df[self.get_columns_order(columns_to_show)]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the (renamed) columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Put rows with a 'paper_page' first. A stable sort keeps the existing
        # order within each group; the default quicksort would scramble it.
        if 'paper_page' in filtered_df.columns:
            paper_pages = filtered_df['paper_page']
            filtered_df = (
                filtered_df.assign(has_paper_page=paper_pages.notna() & (paper_pages != ""))
                .sort_values(by='has_paper_page', ascending=False, kind='stable')
                .drop(columns='has_paper_page')
            )

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column (unchanged if it was never renamed).
        """
        inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
        return inverse_map.get(display_column_name, display_column_name)