ArXivRecommenderSystem / embedding.py
Michael-Geis
created embedding class and updated log
415c066
raw
history blame
1.53 kB
import data_cleaning as clean
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import numpy as np
import json
class embed:
"""A class to handle creating sentence transformer embeddings of arxiv titles and abstracts."""
def prepare_sentences(dataset=pd.DataFrame()):
"""cleans title and abstract of each paper and concatenates them.
Args:
dataset: arxiv dataset
Returns:
list in which entry i is cleaned and concatenated title and abstract of article i.
"""
clean_dataset = clean.clean_title_abstracts(dataset)
return (clean_dataset.title + " " + clean_dataset.abstract).to_list()
def create_sentence_embeddings(self, dataset, model_name):
model = SentenceTransformer(model_name)
sentences = self.prepare_sentences(dataset)
embedding_array = model.encode(sentences=sentences, show_progress_bar=True)
return pd.DataFrame(embedding_array).join(dataset.id)
## Create series object in which each entry is NAN or the list of embedded tags
def rank_msc_tags(self, dataset):
tag_map = clean.msc_encoded_dict()
# Get the list of embedded tags for all tagged rows in a new column
embedded_tags = dataset.msc_tags
dataset['embedded_tags'] = embedded_tags[
dataset.msc_tags.notna()
].apply(lambda x: [tag_map[tag] for tag in x])
## Finish this tomorrow
dataset['semantic_tag_score'] = dataset.apply( ,axis=1)