Michael-Geis committed on
Commit
415c066
1 Parent(s): 9c78a22

created embedding class and updated log

Browse files
Files changed (2) hide show
  1. data_cleaning.py +1 -1
  2. embedding.py +43 -0
data_cleaning.py CHANGED
@@ -305,7 +305,7 @@ def cats_to_msc(cat_list):
305
 
306
  def msc_encoded_dict():
307
  encoded_tags = pd.read_parquet("./data/msc_mini_embeddings.parquet").to_numpy()
308
- return {k: v for (k, v) in zip(msc_tags().values(), encoded_tags)}
309
 
310
 
311
  def doc_encoded_dict():
 
305
 
306
  def msc_encoded_dict():
307
  encoded_tags = pd.read_parquet("./data/msc_mini_embeddings.parquet").to_numpy()
308
+ return {k: v for (k, v) in zip(msc_tags().keys(), encoded_tags)}
309
 
310
 
311
  def doc_encoded_dict():
embedding.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import data_cleaning as clean
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import pandas as pd
4
+ import numpy as np
5
+ import json
6
+
7
+
8
class embed:
    """A class to handle creating sentence transformer embeddings of arxiv titles and abstracts."""

    @staticmethod
    def prepare_sentences(dataset=None):
        """Clean the title and abstract of each paper and concatenate them.

        Declared as a static method so it works both when called on the class
        (``embed.prepare_sentences(df)``) and when called on an instance
        (``self.prepare_sentences(df)``), which is how
        ``create_sentence_embeddings`` uses it.

        Args:
            dataset: arxiv dataset. It is passed to
                ``clean.clean_title_abstracts``, whose result must expose
                ``title`` and ``abstract`` columns. Defaults to an empty
                DataFrame when omitted.

        Returns:
            list in which entry i is the cleaned title and abstract of
            article i, joined by a single space.
        """
        # Build the default per call instead of sharing one DataFrame object
        # across every call through the signature.
        if dataset is None:
            dataset = pd.DataFrame()

        clean_dataset = clean.clean_title_abstracts(dataset)
        return (clean_dataset.title + " " + clean_dataset.abstract).to_list()

    def create_sentence_embeddings(self, dataset, model_name):
        """Embed the cleaned title and abstract of every paper in ``dataset``.

        Args:
            dataset: arxiv dataset; its ``id`` column is joined onto the result.
            model_name: name handed to ``SentenceTransformer`` to load the model.

        Returns:
            DataFrame built from the encoded sentences (one row per sentence)
            with the ``id`` column of ``dataset`` joined on by index.
        """
        model = SentenceTransformer(model_name)
        sentences = self.prepare_sentences(dataset)
        embedding_array = model.encode(sentences=sentences, show_progress_bar=True)

        # NOTE(review): ``join`` aligns on index labels, so this assumes
        # ``dataset`` carries a default 0..n-1 index in row order -- confirm.
        return pd.DataFrame(embedding_array).join(dataset.id)

    def rank_msc_tags(self, dataset):
        """Attach embedded MSC tags to ``dataset``; scoring is not written yet.

        Adds an ``embedded_tags`` column to ``dataset`` in place. Rows whose
        ``msc_tags`` entry is not null get the list of embeddings looked up for
        each of their tags; all other rows are left as NaN.

        Args:
            dataset: arxiv dataset with an ``msc_tags`` column whose non-null
                entries are iterables of tags.

        Raises:
            KeyError: if a tag has no entry in ``clean.msc_encoded_dict()``.
            NotImplementedError: always, after ``embedded_tags`` has been
                added, because the ``semantic_tag_score`` step is unfinished.
        """
        tag_map = clean.msc_encoded_dict()

        # Only tagged rows are mapped; assigning the shorter result back
        # aligns on the index, so untagged rows end up as NaN.
        is_tagged = dataset.msc_tags.notna()
        dataset["embedded_tags"] = dataset.msc_tags[is_tagged].apply(
            lambda tags: [tag_map[tag] for tag in tags]
        )

        # TODO: compute dataset["semantic_tag_score"] row-wise. The draft line
        # was ``dataset.apply( ,axis=1)`` with the scoring function missing,
        # which was a syntax error that made the whole module unimportable.
        raise NotImplementedError(
            "semantic_tag_score computation has not been implemented yet"
        )