leavoigt commited on
Commit
fd9280e
1 Parent(s): 5d407bd

Update utils/vulnerability_classifier.py

Browse files
Files changed (1) hide show
  1. utils/vulnerability_classifier.py +5 -22
utils/vulnerability_classifier.py CHANGED
@@ -28,7 +28,7 @@ label_dict= {0: 'Agricultural communities',
28
  16: 'Urban populations',
29
  17: 'Women and other genders'}
30
 
31
- def getlabels(preds):
32
 
33
  """
34
  Function that takes the numerical predictions as an input and returns a list of the labels.
@@ -37,28 +37,11 @@ def getlabels(preds):
37
 
38
  # Get label names
39
  preds_list = preds.tolist()
 
 
 
40
 
41
- predictions_names=[]
42
-
43
- # loop through each prediction
44
- for ele in preds_list:
45
-
46
- # see if there is a value 1 and retrieve index
47
- try:
48
- index_of_one = ele.index(1)
49
- except ValueError:
50
- index_of_one = "NA"
51
-
52
- # Retrieve the name of the label (if no prediction made = NA)
53
- if index_of_one != "NA":
54
- name = label_dict[index_of_one]
55
- else:
56
- name = "Other"
57
-
58
- # Append name to list
59
- predictions_names.append(name)
60
-
61
- return predictions_names
62
 
63
  @st.cache_resource
64
  def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = None):
 
28
  16: 'Urban populations',
29
  17: 'Women and other genders'}
30
 
31
+ def get_vulnerability_labels(preds):
32
 
33
  """
34
  Function that takes the numerical predictions as an input and returns a list of the labels.
 
37
 
38
  # Get label names
39
  preds_list = preds.tolist()
40
+
41
+ # Get the name of the group where the prediction is equal to "1"
42
+ result = [label_dict[key] for key, value in enumerate(preds_list) if value == 1]
43
 
44
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  @st.cache_resource
47
  def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = None):