leavoigt commited on
Commit
5d407bd
1 Parent(s): 7bb152d

Update utils/target_classifier.py

Browse files
Files changed (1) hide show
  1. utils/target_classifier.py +32 -0
utils/target_classifier.py CHANGED
@@ -14,6 +14,38 @@ _lab_dict = {
14
  '1':'YES',
15
  }
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  @st.cache_resource
18
  def load_targetClassifier(config_file:str = None, classifier_name:str = None):
19
  """
 
14
  '1':'YES',
15
  }
16
 
17
+ def get_target_labels(preds):
18
+
19
+ """
20
+ Function that takes the numerical predictions as an input and returns a list of the labels.
21
+
22
+ """
23
+
24
+ # Get label names
25
+ preds_list = preds.tolist()
26
+
27
+ predictions_names=[]
28
+
29
+ # loop through each prediction
30
+ for ele in preds_list:
31
+
32
+ # see if there is a value 1 and retrieve index
33
+ try:
34
+ index_of_one = ele.index(1)
35
+ except ValueError:
36
+ index_of_one = "NA"
37
+
38
+ # Retrieve the name of the label (if no prediction made = NA)
39
+ if index_of_one != "NA":
40
+ name = label_dict[index_of_one]
41
+ else:
42
+ name = "Other"
43
+
44
+ # Append name to list
45
+ predictions_names.append(name)
46
+
47
+ return predictions_names
48
+
49
  @st.cache_resource
50
  def load_targetClassifier(config_file:str = None, classifier_name:str = None):
51
  """