# ArXivRecommenderSystem / embedding.py
# Michael-Geis
# reorganized
# 7cc8002
import cleaning as clean
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import numpy as np
import json
from sklearn.base import BaseEstimator, TransformerMixin
import os
class Embedder(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer that embeds already-cleaned strings.

    ``transform`` maps a list of clean strings to the numpy array of embeddings
    produced by the SentenceTransformer model named ``model_name``.
    """

    def __init__(self, model_name) -> None:
        super().__init__()
        # Name/path of the SentenceTransformer model; loaded lazily in transform.
        self.model_name = model_name

    def fit(self, X, y=None):
        """Nothing is learned; return the estimator itself."""
        return self

    def transform(self, X, y=None):
        """Encode every string in ``X`` with ``self.model_name``; ``y`` is ignored."""
        return SentenceTransformer(self.model_name).encode(sentences=X)
class FullEmbedder(BaseEstimator, TransformerMixin):
    """A class to handle creating sentence transformer embeddings from a clean arxiv dataset."""

    def fit(self, X, y=None):
        """Nothing to fit; returns ``self`` unchanged."""
        return self

    def transform(
        self,
        X=None,
        y=None,
        model_name=None,
        load_from_file=False,
        path_to_embeddings=None,
    ):
        """Either generate embeddings from a clean ArXivData instance or load them from file.

        The embeddings are attached to ``X`` as ``X.embeddings`` and ``X`` is returned.

        Args:
            X: ArXivData instance that has been cleaned.
            y: Labels; ignored. Defaults to None.
            model_name: Sentence transformer model used to generate embeddings.
                Required unless ``load_from_file`` is True. Defaults to None.
            load_from_file: If True, load precomputed embeddings from
                ``path_to_embeddings`` instead of computing them. Defaults to False.
            path_to_embeddings: Feather file to load embeddings from, or to save
                newly generated embeddings to. When generating and this is not
                given, ``<cwd>/<model_name>_embeddings.feather`` is used, with any
                path separators in ``model_name`` replaced by ``_``. Defaults to None.

        Returns:
            ``X`` with its ``embeddings`` attribute set to a numpy array.

        Raises:
            ValueError: If ``load_from_file`` is True without ``path_to_embeddings``,
                or if embeddings must be generated but ``model_name`` is missing.
        """
        if load_from_file:
            X.embeddings = self._load_embeddings(path_to_embeddings)
            return X

        ## Generate embeddings from X and save as an attribute of X.
        if not model_name:
            raise ValueError("You must specify the sentence transformer model to use.")
        X.embeddings = self._generate_embeddings(X, model_name)

        ## Save the embeddings to the specified path, or, if no path is specified,
        ## use the default path ./<model_name>_embeddings.feather
        if not path_to_embeddings:
            path_to_embeddings = self._default_embeddings_path(model_name)
        self._save_embeddings(X.embeddings, path_to_embeddings)
        return X

    @staticmethod
    def _load_embeddings(path_to_embeddings):
        """Read a feather file of embeddings into a numpy array."""
        if not path_to_embeddings:
            raise ValueError("You must specify a path to load the embeddings from.")
        return pd.read_feather(path_to_embeddings).to_numpy()

    @staticmethod
    def _generate_embeddings(X, model_name):
        """Encode ``X.metadata.doc_strings`` with the given sentence transformer."""
        doc_strings = X.metadata.doc_strings.to_list()
        model = SentenceTransformer(model_name)
        return model.encode(doc_strings, show_progress_bar=True)

    @staticmethod
    def _default_embeddings_path(model_name):
        """Return ``<cwd>/<model_name>_embeddings.feather`` using a flat, file-safe model name."""
        # Hub ids such as "org/model" contain "/", which would otherwise make the
        # default filename point into a (usually nonexistent) subdirectory.
        safe_model_name = model_name.replace("/", "_").replace("\\", "_")
        return os.path.join(os.getcwd(), f"{safe_model_name}_embeddings.feather")

    @staticmethod
    def _save_embeddings(embeddings, path_to_embeddings):
        """Write ``embeddings`` to a feather file, one column per embedding dimension."""
        embeddings_df = pd.DataFrame(embeddings)
        # Store the integer column labels as strings ("0", "1", ...).
        embeddings_df.columns = [str(column_name) for column_name in embeddings_df.columns]
        embeddings_df.to_feather(path_to_embeddings)
class ComputeMSCLabels(BaseEstimator, TransformerMixin):
    """Score each paper's MSC tags against the paper's document embedding.

    ``transform`` adds a ``scored_tags`` column to a metadata DataFrame. The
    value is NaN for rows whose ``msc_tags`` is missing. Otherwise it is the hit
    list returned by ``sentence_transformers.util.semantic_search`` when the
    paper's embedding is searched against the embeddings of that paper's own tags.
    """

    # Maximum number of tag hits kept per paper.
    top_k = 50

    def fit(self, X, y=None):
        """Nothing to fit; returns ``self``."""
        return self

    def transform(self, X, y=None, path_to_embeddings=None):
        """Add a ``scored_tags`` column to ``X``.

        Args:
            X: DataFrame of paper metadata with an ``msc_tags`` column. Row ``i``
                of ``X`` is paired with row ``i`` of the embeddings file, so the
                two must be in the same order (TODO confirm against the code that
                produced the embeddings file).
            y: Ignored.
            path_to_embeddings: Feather file of document embeddings, one row per paper.

        Returns:
            ``X`` (modified in place) with the new ``scored_tags`` column.

        Raises:
            ValueError: If ``path_to_embeddings`` is not given, or if the number of
                embeddings differs from the number of rows in ``X``.
        """
        if not path_to_embeddings:
            raise ValueError("You must specify the path to the document embeddings.")
        # Read the document embeddings once here rather than once per row.
        embeddings = pd.read_feather(path_to_embeddings).to_numpy()
        if len(embeddings) != len(X):
            raise ValueError(
                f"Found {len(embeddings)} embeddings for {len(X)} rows; "
                "X must be in the same row order as the embeddings file."
            )
        tag_to_embedding_dict = clean.msc_encoded_dict()

        has_tags = X.msc_tags.notna().to_list()
        scored_tags = []
        # Walk rows by position: row i of X pairs with row i of the embeddings.
        for position, (tagged, tags) in enumerate(zip(has_tags, X.msc_tags)):
            if not tagged:
                scored_tags.append(np.nan)
                continue
            tag_embeddings = clean.list_mapper(tags, dictionary=tag_to_embedding_dict)
            scored_tags.append(self._score_tags(embeddings[position], tag_embeddings))

        # Assign the whole column at once instead of chained assignment into a
        # slice; object dtype lets each cell hold a list of hits.
        X["scored_tags"] = pd.Series(scored_tags, index=X.index, dtype=object)
        return X

    def get_tag_semantic_scores(self, metadata_row, path_to_embeddings=None, embeddings=None):
        """Score one row's ``tag_embeddings`` against that row's document embedding.

        Args:
            metadata_row: A row (Series) with a ``tag_embeddings`` entry. Its
                label (``metadata_row.name``) is used as the row number in the
                embeddings array. This assumes a default integer index aligned
                with the embeddings (TODO confirm for callers).
            path_to_embeddings: Feather file to read when ``embeddings`` is not given.
            embeddings: Preloaded 2-D embeddings array; avoids re-reading the file.

        Returns:
            The hit list for this row's single query.
        """
        if embeddings is None:
            embeddings = pd.read_feather(path_to_embeddings).to_numpy()
        return self._score_tags(embeddings[metadata_row.name], metadata_row.tag_embeddings)

    def _score_tags(self, document_embedding, tag_embeddings):
        """Return the semantic_search hits of one document embedding against its tag embeddings."""
        # Pass the 1-D row itself, which semantic_search treats as a batch of one
        # query. list(row) would give plain scalars that it cannot stack into a tensor.
        # tag_embeddings is whatever clean.list_mapper produced (format not visible here).
        results = util.semantic_search(
            query_embeddings=document_embedding,
            corpus_embeddings=tag_embeddings,
            top_k=self.top_k,
        )
        # Exactly one query was passed, so return its result list.
        return results[0]
def generate_tag_embeddings(model_name, path_to_tag_dict, path_to_save_embeddings):
    """Embed the distinct tag names of a JSON tag dictionary and save them as parquet.

    Args:
        model_name: Sentence transformer model used to encode the tag names.
        path_to_tag_dict: Path to a JSON object; only its values (the tag names)
            are used.
        path_to_save_embeddings: Destination parquet file. Rows are indexed by
            tag name, with one string-labelled column per embedding dimension.
    """
    # Read the input before loading the model so a bad path fails fast.
    # JSON text is UTF-8, so don't depend on the platform's locale encoding.
    with open(path_to_tag_dict, "r", encoding="utf-8") as file:
        tag_dict = json.load(file)

    # Deduplicate while keeping first-occurrence order. Iterating a set of str
    # yields a per-process order (hash randomization), which made the saved
    # row order differ from run to run.
    tag_name_list = list(dict.fromkeys(tag_dict.values()))

    model = SentenceTransformer(model_name)
    embedded_tag_names = model.encode(sentences=tag_name_list, show_progress_bar=True)

    embedded_tag_names_df = pd.DataFrame(embedded_tag_names, index=tag_name_list)
    # Store the integer column labels as strings ("0", "1", ...).
    embedded_tag_names_df.columns = [str(name) for name in embedded_tag_names_df.columns]
    embedded_tag_names_df.to_parquet(path_to_save_embeddings, index=True)
def load_tag_embeddings(path_to_tag_embeddings):
    """Read a parquet file of tag-name embeddings into a DataFrame.

    Presumably this is the file written by ``generate_tag_embeddings``: tag
    names as the index, one column per embedding dimension. Nothing here enforces that.
    """
    tag_embeddings_df = pd.read_parquet(path_to_tag_embeddings)
    return tag_embeddings_df