Update utils/vulnerability_classifier.py
Browse files
utils/vulnerability_classifier.py
CHANGED
@@ -75,6 +75,9 @@ def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = N
|
|
75 |
found then will look for configfile, else raise error.
|
76 |
Return: document classifier model
|
77 |
"""
|
|
|
|
|
|
|
78 |
if not classifier_name:
|
79 |
if not config_file:
|
80 |
logging.warning("Pass either model name or config file")
|
@@ -84,6 +87,7 @@ def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = N
|
|
84 |
classifier_name = config.get('vulnerability','MODEL')
|
85 |
|
86 |
logging.info("Loading vulnerability classifier")
|
|
|
87 |
# we are using the pipeline as the model is multilabel and DocumentClassifier
|
88 |
# from Haystack doesn't support multilabel
|
89 |
# in pipeline we use 'sigmoid' to explicitly tell pipeline to make it multilabel
|
@@ -93,7 +97,7 @@ def load_vulnerabilityClassifier(config_file:str = None, classifier_name:str = N
|
|
93 |
# task="text-classification",
|
94 |
# top_k = None)
|
95 |
|
96 |
-
#
|
97 |
doc_classifier = SetFitModel.from_pretrained("leavoigt/vulnerability_multilabel")
|
98 |
|
99 |
# doc_classifier = pipeline("text-classification",
|
@@ -112,8 +116,7 @@ def vulnerability_classification(haystack_doc:pd.DataFrame,
|
|
112 |
"""
|
113 |
Text-Classification on the list of texts provided. Classifier provides the
|
114 |
most appropriate label for each text. these labels are in terms of if text
|
115 |
-
|
116 |
-
Params
|
117 |
---------
|
118 |
haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
|
119 |
contains the list of paragraphs in different format, here the list of
|
|
|
75 |
found then will look for configfile, else raise error.
|
76 |
Return: document classifier model
|
77 |
"""
|
78 |
+
|
79 |
+
# If no classifier given
|
80 |
+
|
81 |
if not classifier_name:
|
82 |
if not config_file:
|
83 |
logging.warning("Pass either model name or config file")
|
|
|
87 |
classifier_name = config.get('vulnerability','MODEL')
|
88 |
|
89 |
logging.info("Loading vulnerability classifier")
|
90 |
+
|
91 |
# we are using the pipeline as the model is multilabel and DocumentClassifier
|
92 |
# from Haystack doesn't support multilabel
|
93 |
# in pipeline we use 'sigmoid' to explicitly tell pipeline to make it multilabel
|
|
|
97 |
# task="text-classification",
|
98 |
# top_k = None)
|
99 |
|
100 |
+
# Download model from HF Hub
|
101 |
doc_classifier = SetFitModel.from_pretrained("leavoigt/vulnerability_multilabel")
|
102 |
|
103 |
# doc_classifier = pipeline("text-classification",
|
|
|
116 |
"""
|
117 |
Text-Classification on the list of texts provided. Classifier provides the
|
118 |
most appropriate label for each text. These labels indicate whether the text
|
119 |
+
references a group in a vulnerable situation.
|
|
|
120 |
---------
|
121 |
haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
|
122 |
contains the list of paragraphs in different format, here the list of
|