leavoigt committed on
Commit
e000f68
·
1 Parent(s): 9a937ee

Update utils/vulnerability_classifier.py

Browse files
Files changed (1) hide show
  1. utils/vulnerability_classifier.py +10 -12
utils/vulnerability_classifier.py CHANGED
@@ -75,16 +75,16 @@ def vulnerability_classification(haystack_doc:pd.DataFrame,
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
  logging.info("Working on vulnerability Identification")
78
- haystack_doc['Indicator Label'] = 'NA'
79
- haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
80
 
81
- df1 = haystack_doc[haystack_doc['PA_check'] == True]
82
- df = haystack_doc[haystack_doc['PA_check'] == False]
83
  if not classifier_model:
84
  classifier_model = st.session_state['vulnerability_classifier']
85
 
86
- predictions = classifier_model(list(df1.text))
87
-
88
  list_ = []
89
  for i in range(len(predictions)):
90
 
@@ -99,11 +99,9 @@ def vulnerability_classification(haystack_doc:pd.DataFrame,
99
  truth_df = truth_df.astype(float) >= threshold
100
  truth_df = truth_df.astype(str)
101
  categories = list(truth_df.columns)
102
- truth_df['Indicator Label'] = truth_df.apply(lambda x: {i if x[i]=='True' else
103
  None for i in categories}, axis=1)
104
- truth_df['Indicator Label'] = truth_df.apply(lambda x: list(x['Indicator Label']
105
  -{None}),axis=1)
106
- df1['Indicator Label'] = list(truth_df['Indicator Label'])
107
- df = pd.concat([df,df1])
108
- df = df.drop(columns = ['PA_check'])
109
- return df
 
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
  logging.info("Working on vulnerability Identification")
78
+ haystack_doc['Sector Label'] = 'NA'
79
+ # haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
80
 
81
+ # df1 = haystack_doc[haystack_doc['PA_check'] == True]
82
+ # df = haystack_doc[haystack_doc['PA_check'] == False]
83
  if not classifier_model:
84
  classifier_model = st.session_state['vulnerability_classifier']
85
 
86
+ predictions = classifier_model(list(haystack_doc.text))
87
+
88
  list_ = []
89
  for i in range(len(predictions)):
90
 
 
99
  truth_df = truth_df.astype(float) >= threshold
100
  truth_df = truth_df.astype(str)
101
  categories = list(truth_df.columns)
102
+ truth_df['Vulnerability Label'] = truth_df.apply(lambda x: {i if x[i]=='True' else
103
  None for i in categories}, axis=1)
104
+ truth_df['Vulnerability Label'] = truth_df.apply(lambda x: list(x['Vulnerability Label']
105
  -{None}),axis=1)
106
+ haystack_doc['Vulnerability Label'] = list(truth_df['Vulnerability Label'])
107
+ return haystack_doc