Michael-Geis committed on
Commit
1c04d6f
1 Parent(s): 8cfea01

modified dependencies

Browse files
Files changed (3) hide show
  1. postprocess.py +28 -31
  2. preprocess.py +0 -3
  3. requirements.txt +3 -6
postprocess.py CHANGED
@@ -1,7 +1,4 @@
1
- from sklearn.base import TransformerMixin, BaseEstimator
2
  import json
3
- import pandas as pd
4
- import numpy as np
5
 
6
 
7
  def postprocess(model_output):
@@ -15,38 +12,38 @@ def postprocess(model_output):
15
  return sorted([subject_dict[tag] for tag in predicted_tags])
16
 
17
 
18
- class ModelOutputDecoder(BaseEstimator, TransformerMixin):
19
- def fit(self, X, y=None):
20
- return self
21
 
22
- def transform(self, X, y=None):
23
- if y is None:
24
- return X
25
 
26
- ## Load label dictionary
27
- with open("./data/arxiv-label-dict.json") as file:
28
- string_dict = file.read()
29
- label_dict = json.loads(string_dict)
30
- col_list = list(label_dict.keys())
31
 
32
- def decode_label(label):
33
- ## For a row of y (individual label) returns the list of english subjects corresponding to this label
34
- return [label_dict[col_list[index]] for index in np.where(label == 1)[0]]
35
 
36
- num_rows, _ = y.shape
37
 
38
- decoded_labels = []
39
- for i in range(num_rows):
40
- decoded_labels.append(decode_label(y[i, :]))
41
 
42
- decoded_labels_as_series = pd.Series(
43
- decoded_labels, name="decoded_labels", index=X.index
44
- )
45
 
46
- return pd.merge(
47
- left=X,
48
- left_index=True,
49
- right=decoded_labels_as_series,
50
- right_index=True,
51
- validate="1:1",
52
- )
 
 
1
  import json
 
 
2
 
3
 
4
  def postprocess(model_output):
 
12
  return sorted([subject_dict[tag] for tag in predicted_tags])
13
 
14
 
15
+ # class ModelOutputDecoder(BaseEstimator, TransformerMixin):
16
+ # def fit(self, X, y=None):
17
+ # return self
18
 
19
+ # def transform(self, X, y=None):
20
+ # if y is None:
21
+ # return X
22
 
23
+ # ## Load label dictionary
24
+ # with open("./data/arxiv-label-dict.json") as file:
25
+ # string_dict = file.read()
26
+ # label_dict = json.loads(string_dict)
27
+ # col_list = list(label_dict.keys())
28
 
29
+ # def decode_label(label):
30
+ # ## For a row of y (individual label) returns the list of english subjects corresponding to this label
31
+ # return [label_dict[col_list[index]] for index in np.where(label == 1)[0]]
32
 
33
+ # num_rows, _ = y.shape
34
 
35
+ # decoded_labels = []
36
+ # for i in range(num_rows):
37
+ # decoded_labels.append(decode_label(y[i, :]))
38
 
39
+ # decoded_labels_as_series = pd.Series(
40
+ # decoded_labels, name="decoded_labels", index=X.index
41
+ # )
42
 
43
+ # return pd.merge(
44
+ # left=X,
45
+ # left_index=True,
46
+ # right=decoded_labels_as_series,
47
+ # right_index=True,
48
+ # validate="1:1",
49
+ # )
preprocess.py CHANGED
@@ -1,6 +1,3 @@
1
- from sklearn.pipeline import Pipeline
2
- from sklearn.base import BaseEstimator, TransformerMixin
3
- import pandas as pd
4
  import regex
5
 
6
 
 
 
 
 
1
  import regex
2
 
3
 
requirements.txt CHANGED
@@ -1,8 +1,5 @@
1
- numpy
2
- pandas
3
- scikit-learn
4
- scikit-multilearn
5
  arxiv
 
 
6
  transformers
7
- torch
8
- datasets
 
 
 
 
 
1
  arxiv
2
+ regex
3
+ scikit-learn
4
  transformers
5
+ torch