File size: 4,380 Bytes
c2fc17c
415c066
 
 
 
8f895f2
 
415c066
 
8f895f2
 
415c066
8f895f2
 
 
 
8573c13
 
 
 
 
 
8f895f2
 
415c066
 
8f895f2
 
 
 
 
415c066
8f895f2
 
415c066
 
8f895f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551e737
 
 
 
8f895f2
551e737
 
8f895f2
 
c8a32d4
 
8f895f2
 
 
 
 
 
 
 
415c066
8f895f2
415c066
8f895f2
415c066
8f895f2
 
 
 
 
 
 
415c066
8f895f2
415c066
8f895f2
 
 
 
 
 
 
415c066
8f895f2
d01b293
 
 
f2ef607
d01b293
 
 
 
c8cfc43
d01b293
5964707
 
 
 
baf74d8
 
5964707
d01b293
baf74d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import cleaning as clean
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import numpy as np
import json
from sklearn.base import BaseEstimator, TransformerMixin
import os


class Embedder(BaseEstimator, TransformerMixin):
    """A class to handle creating sentence transformer embeddings from a clean arxiv dataset."""

    def fit(self, X, y=None):
        """No-op fit; present only for sklearn pipeline compatibility."""
        return self

    def transform(
        self,
        X=None,
        y=None,
        model_name=None,
        load_from_file=False,
        path_to_embeddings=None,
    ):
        """Either generates embeddings from a clean ArXivData instance or loads embeddings from file.

        Args:
            X: ArXivData instance that has been cleaned.
            y: Labels. Defaults to None.
            model_name: Sentence transformer model used to generate embeddings. Defaults to None.
            load_from_file: Boolean used to specify whether to calculate embeddings or load from file. Defaults to False.
            path_to_embeddings: path to the location to save embeddings to or load embeddings from. Defaults to None.

        Returns:
            X, with its ``embeddings`` attribute set to a numpy array. Returning X
            (rather than None, as before) lets this transformer compose in a
            sklearn Pipeline.

        Raises:
            ValueError: If load_from_file is True without a specified path to load from,
                or if neither load_from_file nor model_name is given.
        """

        if load_from_file:
            if not path_to_embeddings:
                # Fixed message: this branch LOADS embeddings; the old text said "store".
                raise ValueError(
                    "You must specify a path to load the embeddings from."
                )
            X.embeddings = pd.read_feather(path_to_embeddings).to_numpy()
        else:
            ## Generate embeddings from X and save as an attribute of X.

            if not model_name:
                raise ValueError(
                    "You must specify the sentence transformer model to use."
                )

            doc_strings = (X.metadata.doc_strings).to_list()
            model = SentenceTransformer(model_name)
            embeddings = model.encode(doc_strings, show_progress_bar=True)
            X.embeddings = embeddings

            ## Save the embeddings to the specified path, or, if no path is specified, use the default path
            ## default path = ./model_name_embeddings.feather
            ## NOTE(review): a model_name containing "/" (e.g. an HF hub id) would
            ## produce an invalid default filename — confirm callers pass bare names.

            # Feather requires string column names.
            embeddings_df = pd.DataFrame(embeddings)
            embeddings_df.columns = [
                str(column_name) for column_name in embeddings_df.columns
            ]

            if not path_to_embeddings:
                path_to_embeddings = os.path.join(
                    os.getcwd(), f"{model_name}_embeddings.feather"
                )

            embeddings_df.to_feather(path_to_embeddings)

        return X


class ComputeMSCLabels(BaseEstimator, TransformerMixin):
    """Score each document's MSC subject tags against its embedding via semantic search."""

    def fit(self, X, y=None):
        """No-op fit; present only for sklearn pipeline compatibility."""
        return self

    def transform(self, X, y=None, path_to_embeddings=None):
        """Add a ``scored_tags`` column scoring each row's MSC tags against its embedding.

        Args:
            X: DataFrame with (at least) ``msc_tags`` and ``doc_strings`` columns.
            y: Labels. Defaults to None.
            path_to_embeddings: Path to the feather file of document embeddings.

        Returns:
            X, with ``scored_tags`` populated for rows that have MSC tags
            (NaN elsewhere).
        """
        tag_to_embedding_dict = clean.msc_encoded_dict()

        X["scored_tags"] = np.nan

        # Compute the mask once and reuse it, so the rows we score are exactly
        # the rows we assign back to. The original filtered on X.msc_tags but
        # assigned via X.metadata.msc_tags, which does not exist on a DataFrame.
        tagged_mask = X.msc_tags.notna()

        # .copy() so adding the temporary tag_embeddings column does not
        # trigger chained-assignment warnings or mutate a view of X.
        X_tagged_rows = X[tagged_mask].copy()

        X_tagged_rows["tag_embeddings"] = X_tagged_rows.msc_tags.apply(
            clean.list_mapper, dictionary=tag_to_embedding_dict
        )

        # Load the document embeddings once here instead of re-reading the
        # feather file for every row inside get_tag_semantic_scores.
        embeddings = pd.read_feather(path_to_embeddings).to_numpy()
        tag_scores = X_tagged_rows.apply(
            self.get_tag_semantic_scores,
            path_to_embeddings=path_to_embeddings,
            embeddings=embeddings,
            axis=1,
        )

        # .loc assignment instead of chained X.scored_tags[...] = ... setitem.
        X.loc[tagged_mask, "scored_tags"] = tag_scores

        return X

    def get_tag_semantic_scores(self, metadata_row, path_to_embeddings, embeddings=None):
        """Semantic-search a row's document embedding against its tag embeddings.

        Args:
            metadata_row: A row of the tagged DataFrame; must expose
                ``doc_strings`` and ``tag_embeddings``.
            path_to_embeddings: Feather file of document embeddings; only read
                when ``embeddings`` is not supplied (backward-compatible path).
            embeddings: Optional pre-loaded document-embedding array, so batch
                callers avoid re-reading the file per row.

        Returns:
            The top-k (k=50) hit list for this row's query embedding.
        """
        if embeddings is None:
            embeddings = pd.read_feather(path_to_embeddings).to_numpy()
        # NOTE(review): metadata_row.doc_strings.index is assumed to be a
        # positional index into the embeddings array — confirm upstream shape.
        results = util.semantic_search(
            query_embeddings=list(embeddings[metadata_row.doc_strings.index, :]),
            corpus_embeddings=metadata_row.tag_embeddings,
            top_k=50,
        )

        return results[0]


def generate_tag_embeddings(model_name, path_to_tag_dict, path_to_save_embeddings):
    """Embed the unique MSC tag names with a sentence transformer and save to parquet.

    Args:
        model_name: Name of the sentence transformer model to load.
        path_to_tag_dict: Path to a JSON file mapping tag codes to tag names.
        path_to_save_embeddings: Destination parquet file for the embeddings,
            indexed by tag name.
    """
    with open(path_to_tag_dict, "r") as tag_file:
        tag_dict = json.load(tag_file)

    # De-duplicate the tag names before encoding.
    unique_tag_names = list(set(tag_dict.values()))

    transformer = SentenceTransformer(model_name)
    encoded_tags = transformer.encode(
        sentences=unique_tag_names, show_progress_bar=True
    )

    encoded_df = pd.DataFrame(encoded_tags, index=unique_tag_names)
    # Parquet needs string column names.
    encoded_df.columns = [str(column) for column in encoded_df.columns]
    encoded_df.to_parquet(path_to_save_embeddings, index=True)


def load_tag_embeddings(path_to_tag_embeddings):
    """Read previously saved tag embeddings back from a parquet file.

    Args:
        path_to_tag_embeddings: Path to the parquet file written by
            ``generate_tag_embeddings``.

    Returns:
        pd.DataFrame: Tag embeddings, indexed by tag name.
    """
    tag_embeddings = pd.read_parquet(path_to_tag_embeddings)
    return tag_embeddings