## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv
## Requires the Kaggle CLI to be installed
## Using subprocess to run the Kaggle CLI commands instead of the Kaggle API,
## as it allows for anonymous downloads without needing to sign in

import subprocess
from datasets import load_dataset # To load dataset without breaking ram
from multiprocessing import cpu_count # To get the number of cores
from sentence_transformers import SentenceTransformer # For embedding the text
import torch # For gpu
import pandas as pd # Data manipulation
from huggingface_hub import snapshot_download # Download previous embeddings
import os # Folder and file creation
from tqdm import tqdm # Progress bar
tqdm.pandas() # Progress bar for pandas
from mixedbread_ai.client import MixedbreadAI # For embedding the text
from dotenv import dotenv_values # To load environment variables
import numpy as np # For array manipulation
from huggingface_hub import HfApi # To transact with huggingface.co
import sys # To quit the script
from time import time, sleep # To time the script
from datetime import datetime # To get the current date, time and year

# Start timer
start = time()

################################################################################
# Configuration

# Get current year
year = str(datetime.now().year)

# Flag to force download and conversion even if files already exist
FORCE = True

# Flag to embed the data locally, otherwise the mxbai API is used to embed
LOCAL = False

# Flag to upload the data to the Hugging Face Hub
UPLOAD = True

# Flag to binarise the data
BINARY = True

# Print the configuration
print('Configuration:')
print(f'Year: {year}')
print(f'Force: {FORCE}')
print(f'Local: {LOCAL}')
print(f'Upload: {UPLOAD}')
print(f'Binary: {BINARY}')

########################################
# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"

# Number of cores to use for multiprocessing
num_cores = cpu_count() - 1

# Setup transaction details
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"

# Import secrets
config = dotenv_values(".env")

def is_running_in_huggingface_space():
    return "SPACE_ID" in os.environ

################################################################################
# Download the dataset

# Dataset name
dataset_path = 'Cornell-University/arxiv'

# Download folder
download_folder = 'data'

# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'

## Download the dataset if it doesn't exist
if not os.path.exists(download_file) or FORCE:

    print(f'Downloading {download_file}, if it exists it will be overwritten')
    print('Set FORCE to False to skip download if the file already exists')

    subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])

    print(f'Downloaded {download_file}')

else:

    print(f'{download_file} already exists, skipping download')
    print('Set FORCE = True to force download')

################################################################################
# Filter by year and convert to parquet
# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping

# Load metadata
print("Loading json metadata")
dataset = load_dataset("json", data_files=download_file)

# Split metadata by year

# Convert to pandas
print("Converting metadata into pandas")
arxiv_metadata_all = dataset['train'].to_pandas()

########################################
# Function to extract the month and year of publication from an arXiv ID
# https://info.arxiv.org/help/arxiv_identifier.html
def extract_month_year(arxiv_id, what='month'):

    # Identify the relevant YYMM part based on the arXiv ID format
    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]

    # Convert the year-month string to a datetime object
    date = datetime.strptime(yymm, '%y%m')

    # Return the desired part based on the input parameter
    return date.strftime('%B') if what == 'month' else date.strftime('%Y')
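# Illustrative sanity check (added example, not part of the original pipeline;
# safe to remove). It assumes the two standard arXiv ID schemes documented at
# the link above: old-style IDs such as 'hep-th/9901001' and new-style IDs
# such as '2103.12345'.
assert extract_month_year('hep-th/9901001', what='year') == '1999'  # old-style ID
assert extract_month_year('2103.12345', what='year') == '2021'      # new-style ID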
########################################
# Add year to metadata
print("Adding year to metadata")
arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')

# Filter by year
print(f"Filtering metadata by year: {year}")
arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]

################################################################################
# Load Model

if LOCAL:

    print("Setting up local embedding model")
    print("To use the mxbai API, set LOCAL = False")

    # Make the app device agnostic
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load a pretrained Sentence Transformer model and move it to the appropriate device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name)
    model = model.to(device)

else:

    print("Setting up mxbai API client")
    print("To use local resources, set LOCAL = True")

    # Setup mxbai
    if is_running_in_huggingface_space():
        mxbai_api_key = os.getenv("MXBAI_API_KEY")
    else:
        mxbai_api_key = config["MXBAI_API_KEY"]

    mxbai = MixedbreadAI(api_key=mxbai_api_key)

########################################
# Function that does the embedding
def embed(input_text):

    if LOCAL:

        # Calculate embeddings by calling model.encode(), specifying the device
        embedding = model.encode(input_text, device=device, precision="float32")

        # Enforce 32-bit float precision
        embedding = np.array(embedding, dtype=np.float32)

    else:

        # Avoid rate limit from the API
        sleep(0.2)

        # Calculate embeddings by calling mxbai.embeddings()
        result = mxbai.embeddings(
            model=model_name,
            input=input_text,
            normalized=True,
            encoding_format='float',
            truncation_strategy='end'
        )

        # Enforce 32-bit float precision
        embedding = np.array(result.data[0].embedding, dtype=np.float32)

    return embedding

########################################
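# Illustrative usage note (added example, not part of the original pipeline):
# embed() returns a 1-D float32 numpy array for a single abstract, e.g.
#     vec = embed("We study the asymptotic behaviour of ...")
#     vec.dtype  -> dtype('float32')
# mxbai-embed-large-v1 is documented as producing 1024-dimensional embeddings,
# so vec.shape is expected to be (1024,), whether computed locally or via the API.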
################################################################################
# Gather preexisting embeddings

# Subfolder in the repo of the dataset where the file is stored
folder_in_repo = "data"
allow_patterns = f"{folder_in_repo}/{year}.parquet"

# Where to store the local copy of the dataset
local_dir = repo_id

# Set repo type
repo_type = "dataset"

# Create local directory
os.makedirs(local_dir, exist_ok=True)

# Download the repo
snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)

try:

    # Gather previous embed file
    previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'

    # Load previous_embed
    print(f"Loading previously embedded file: {previous_embed}")
    previous_embeddings = pd.read_parquet(previous_embed)

except Exception as e:

    print(f"Errored out with: {e}")
    print(f"No previous embeddings found for year: {year}")
    print("Creating new embeddings for all papers")
    previous_embeddings = pd.DataFrame(columns=['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url'])

########################################
# Embed the new abstracts

# Find papers that are not in the previous embeddings
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]

# Drop duplicates based on the 'id' column
new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)

# Number of new papers
num_new_papers = len(new_papers)

# What if there are no new papers?
if num_new_papers == 0:
    print(f"No new papers found for year: {year}")
    print("Exiting")
    sys.exit()

# Create a column for embeddings
print(f"Creating new embeddings for: {num_new_papers} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

####################
print("Adding url and month columns")

# Add URL column
new_papers['url'] = 'https://arxiv.org/abs/' + new_papers['id']

# Add month column
new_papers['month'] = new_papers['id'].progress_apply(extract_month_year, what='month')

####################
print("Removing newline characters from title, authors, categories, abstract")

# Remove newline characters from the title, authors, categories and abstract columns
new_papers['title'] = new_papers['title'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['authors'] = new_papers['authors'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['categories'] = new_papers['categories'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['abstract'] = new_papers['abstract'].astype(str).str.replace('\n', ' ', regex=False)

####################
print("Trimming title, authors, categories, abstract")

# Trim title to 512 characters
new_papers['title'] = new_papers['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)

# Trim categories to 128 characters
new_papers['categories'] = new_papers['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)

# Trim authors to 128 characters
new_papers['authors'] = new_papers['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)

# Trim abstract to 3072 characters
new_papers['abstract'] = new_papers['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)
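# Optional sanity check before merging and saving (illustrative addition, not
# part of the original pipeline; safe to remove). Every embedding is cast to
# float32 in embed(); the printed dimension should be 1024 if the model
# behaves as documented.
assert new_papers['vector'].iloc[0].dtype == np.float32
print(f"Embedding dimension: {len(new_papers['vector'].iloc[0])}")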
####################
print("Concatenating previously embedded dataframe with new embeddings")

# Select the id, vector and $meta columns to retain
selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']

# Merge previous embeddings and new embeddings
new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])

# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)

# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")

# Keeping index=False to avoid saving the index as a separate column in the parquet file
# This keeps Milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)

################################################################################
# Upload the new embeddings to the repo

if UPLOAD:

    print(f"Uploading new embeddings to: {repo_id}")

    # Setup Hugging Face API
    if is_running_in_huggingface_space():
        access_token = os.getenv("HF_API_KEY")
    else:
        access_token = config["HF_API_KEY"]

    api = HfApi(token=access_token)

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")

    print(f"Upload complete for year: {year}")

else:

    print("Not uploading new embeddings to the repo")
    print("To upload new embeddings, set UPLOAD to True")

################################################################################
# Binarise the data

if BINARY:

    print(f"Binarising the data for year: {year}")
    print("Set BINARY = False to not binarise the embeddings")

    # Function to convert a dense vector to a binary vector:
    # each dimension becomes one bit (1 if the value is >= 0, else 0),
    # and the bits are packed 8 per byte with np.packbits
    def dense_to_binary(dense_vector):
        return np.packbits(np.where(dense_vector >= 0, 1, 0)).tobytes()

    # Create a folder to store binary embeddings
    binary_folder = f"{year}-binary-embed"
    os.makedirs(binary_folder, exist_ok=True)

    # Convert the dense vectors to binary vectors
    new_embeddings['vector'] = new_embeddings['vector'].progress_apply(dense_to_binary)

    # Save the binary embeddings to a parquet file
    new_embeddings.to_parquet(f'{binary_folder}/{year}.parquet', index=False)

if BINARY and UPLOAD:

    # Setup transaction details
    repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus_binary"
    repo_type = "dataset"
    api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)

    # Subfolder in the repo of the dataset where the file is stored
    folder_in_repo = "data"

    print(f"Uploading binary embeddings to {repo_id} from folder {binary_folder}")

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=binary_folder, path_in_repo=folder_in_repo, repo_type=repo_type)

    print("Upload complete")

else:

    print("Not uploading binary embeddings to the repo")
    print("To upload embeddings, set UPLOAD and BINARY both to True")

################################################################################
# Track time
end = time()

# Calculate and show time taken
print(f"Time taken: {end - start} seconds")

print("Done!")