Spaces:
Runtime error
Runtime error
Michael-Geis
commited on
Commit
•
415c066
1
Parent(s):
9c78a22
created embedding class and updated log
Browse files- data_cleaning.py +1 -1
- embedding.py +43 -0
data_cleaning.py
CHANGED
@@ -305,7 +305,7 @@ def cats_to_msc(cat_list):
|
|
305 |
|
306 |
def msc_encoded_dict():
|
307 |
encoded_tags = pd.read_parquet("./data/msc_mini_embeddings.parquet").to_numpy()
|
308 |
-
return {k: v for (k, v) in zip(msc_tags().
|
309 |
|
310 |
|
311 |
def doc_encoded_dict():
|
|
|
305 |
|
306 |
def msc_encoded_dict():
|
307 |
encoded_tags = pd.read_parquet("./data/msc_mini_embeddings.parquet").to_numpy()
|
308 |
+
return {k: v for (k, v) in zip(msc_tags().keys(), encoded_tags)}
|
309 |
|
310 |
|
311 |
def doc_encoded_dict():
|
embedding.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import data_cleaning as clean
|
2 |
+
from sentence_transformers import SentenceTransformer, util
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
|
7 |
+
|
8 |
+
class embed:
|
9 |
+
"""A class to handle creating sentence transformer embeddings of arxiv titles and abstracts."""
|
10 |
+
|
11 |
+
def prepare_sentences(dataset=pd.DataFrame()):
|
12 |
+
"""cleans title and abstract of each paper and concatenates them.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
dataset: arxiv dataset
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
list in which entry i is cleaned and concatenated title and abstract of article i.
|
19 |
+
"""
|
20 |
+
|
21 |
+
clean_dataset = clean.clean_title_abstracts(dataset)
|
22 |
+
return (clean_dataset.title + " " + clean_dataset.abstract).to_list()
|
23 |
+
|
24 |
+
def create_sentence_embeddings(self, dataset, model_name):
|
25 |
+
model = SentenceTransformer(model_name)
|
26 |
+
sentences = self.prepare_sentences(dataset)
|
27 |
+
embedding_array = model.encode(sentences=sentences, show_progress_bar=True)
|
28 |
+
|
29 |
+
return pd.DataFrame(embedding_array).join(dataset.id)
|
30 |
+
|
31 |
+
## Create series object in which each entry is NAN or the list of embedded tags
|
32 |
+
|
33 |
+
def rank_msc_tags(self, dataset):
|
34 |
+
tag_map = clean.msc_encoded_dict()
|
35 |
+
# Get the list of embedded tags for all tagged rows in a new column
|
36 |
+
embedded_tags = dataset.msc_tags
|
37 |
+
dataset['embedded_tags'] = embedded_tags[
|
38 |
+
dataset.msc_tags.notna()
|
39 |
+
].apply(lambda x: [tag_map[tag] for tag in x])
|
40 |
+
|
41 |
+
## Finish this tomorrow
|
42 |
+
|
43 |
+
dataset['semantic_tag_score'] = dataset.apply( ,axis=1)
|