leavoigt commited on
Commit
95c0e35
1 Parent(s): ef9edc5

Rename utils/indicator_classifier.py to utils/vulnerability_classifier.py

Browse files
utils/{indicator_classifier.py → vulnerability_classifier.py} RENAMED
@@ -10,7 +10,7 @@ from transformers import pipeline
10
 
11
 
12
  @st.cache_resource
13
- def load_indicatorClassifier(config_file:str = None, classifier_name:str = None):
14
  """
15
  loads the document classifier using haystack, where the name/path of model
16
  in HF-hub as string is used to fetch the model object.Either configfile or
@@ -30,9 +30,9 @@ def load_indicatorClassifier(config_file:str = None, classifier_name:str = None)
30
  return
31
  else:
32
  config = getconfig(config_file)
33
- classifier_name = config.get('indicator','MODEL')
34
 
35
- logging.info("Loading indicator classifier")
36
  # we are using the pipeline as the model is multilabel and DocumentClassifier
37
  # from Haystack doesnt support multilabel
38
  # in pipeline we use 'sigmoid' to explicitly tell pipeline to make it multilabel
@@ -51,7 +51,7 @@ def load_indicatorClassifier(config_file:str = None, classifier_name:str = None)
51
 
52
 
53
  @st.cache_data
54
- def indicator_classification(haystack_doc:pd.DataFrame,
55
  threshold:float = 0.5,
56
  classifier_model:pipeline= None
57
  )->Tuple[DataFrame,Series]:
@@ -74,14 +74,14 @@ def indicator_classification(haystack_doc:pd.DataFrame,
74
  x: Series object with the unique SDG covered in the document uploaded and
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
- logging.info("Working on Indicator Identification")
78
  haystack_doc['Indicator Label'] = 'NA'
79
  haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
80
 
81
  df1 = haystack_doc[haystack_doc['PA_check'] == True]
82
  df = haystack_doc[haystack_doc['PA_check'] == False]
83
  if not classifier_model:
84
- classifier_model = st.session_state['indicator_classifier']
85
 
86
  predictions = classifier_model(list(df1.text))
87
 
 
10
 
11
 
12
  @st.cache_resource
13
+ def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = None):
14
  """
15
  loads the document classifier using haystack, where the name/path of model
16
  in HF-hub as string is used to fetch the model object.Either configfile or
 
30
  return
31
  else:
32
  config = getconfig(config_file)
33
+ classifier_name = config.get('vulnerability','MODEL')
34
 
35
+ logging.info("Loading vulnerability classifier")
36
  # we are using the pipeline as the model is multilabel and DocumentClassifier
37
  # from Haystack doesnt support multilabel
38
  # in pipeline we use 'sigmoid' to explicitly tell pipeline to make it multilabel
 
51
 
52
 
53
  @st.cache_data
54
+ def vulnerability_classification(haystack_doc:pd.DataFrame,
55
  threshold:float = 0.5,
56
  classifier_model:pipeline= None
57
  )->Tuple[DataFrame,Series]:
 
74
  x: Series object with the unique SDG covered in the document uploaded and
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
+ logging.info("Working on vulnerability Identification")
78
  haystack_doc['Indicator Label'] = 'NA'
79
  haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
80
 
81
  df1 = haystack_doc[haystack_doc['PA_check'] == True]
82
  df = haystack_doc[haystack_doc['PA_check'] == False]
83
  if not classifier_model:
84
+ classifier_model = st.session_state['vulnerability_classifier']
85
 
86
  predictions = classifier_model(list(df1.text))
87