leavoigt commited on
Commit
af925e8
1 Parent(s): 7329ea1

Update utils/vulnerability_classifier.py

Browse files
Files changed (1) hide show
  1. utils/vulnerability_classifier.py +47 -4
utils/vulnerability_classifier.py CHANGED
@@ -9,7 +9,51 @@ import streamlit as st
9
  from transformers import pipeline
10
  from setfit import SetFitModel
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @st.cache_resource
14
  def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = None):
15
  """
@@ -89,12 +133,11 @@ def vulnerability_classification(haystack_doc:pd.DataFrame,
89
 
90
  predictions = classifier_model(list(haystack_doc.text))
91
 
92
- haystack_doc['Vulnerability Label'] = predictions
93
 
94
- # list_ = []
95
- # for i in range(len(predictions)):
96
 
97
- # temp = predictions[i]
98
  # placeholder = {}
99
  # for j in range(len(temp)):
100
  # placeholder[temp[j]['label']] = temp[j]['score']
 
9
  from transformers import pipeline
10
  from setfit import SetFitModel
11
 
12
+ label_dict= {0: 'Agricultural communities',
13
+ 1: 'Children',
14
+ 2: 'Coastal communities',
15
+ 3: 'Ethnic, racial or other minorities',
16
+ 4: 'Fishery communities',
17
+ 5: 'Informal sector workers',
18
+ 6: 'Members of indigenous and local communities',
19
+ 7: 'Migrants and displaced persons',
20
+ 8: 'Older persons',
21
+ 9: 'Other',
22
+ 10: 'Persons living in poverty',
23
+ 11: 'Persons with disabilities',
24
+ 12: 'Persons with pre-existing health conditions',
25
+ 13: 'Residents of drought-prone regions',
26
+ 14: 'Rural populations',
27
+ 15: 'Sexual minorities (LGBTQI+)',
28
+ 16: 'Urban populations',
29
+ 17: 'Women and other genders'}
30
 
31
+ def getlabels(preds):
32
+ # Get label names
33
+ preds_list = preds.tolist()
34
+
35
+ predictions_names=[]
36
+
37
+ # loop through each prediction
38
+ for ele in preds_list:
39
+
40
+ # see if there is a value 1 and retrieve index
41
+ try:
42
+ index_of_one = ele.index(1)
43
+ except ValueError:
44
+ index_of_one = "NA"
45
+
46
+ # Retrieve the name of the label (if no prediction made = NA)
47
+ if index_of_one != "NA":
48
+ name = label_dict[index_of_one]
49
+ else:
50
+ name = "NA"
51
+
52
+ # Append name to list
53
+ predictions_names.append(name)
54
+
55
+ return predictions_names
56
+
57
  @st.cache_resource
58
  def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = None):
59
  """
 
133
 
134
  predictions = classifier_model(list(haystack_doc.text))
135
 
136
+
137
 
138
+ pred_labels = getlabels(predictions)
 
139
 
140
+ haystack_doc['Vulnerability Label'] = pred_labels
141
  # placeholder = {}
142
  # for j in range(len(temp)):
143
  # placeholder[temp[j]['label']] = temp[j]['score']