vulnerability_2_1 / utils /target_classifier.py
leavoigt's picture
Update utils/target_classifier.py
3efd370
raw
history blame
No virus
4.97 kB
# from typing import List, Tuple
# from typing_extensions import Literal
# import logging
# import pandas as pd
# from pandas import DataFrame, Series
# from utils.config import getconfig
# from utils.preprocessing import processingpipeline
# import streamlit as st
# from transformers import pipeline
# ## Labels dictionary ###
# _lab_dict = {
# '0':'NO',
# '1':'YES',
# }
# def get_target_labels(preds):
# """
# Function that takes the numerical predictions as an input and returns a list of the labels.
# """
# # Get label names
# preds_list = preds.tolist()
# predictions_names=[]
# # loop through each prediction
# for ele in preds_list:
# # see if there is a value 1 and retrieve index
# try:
# index_of_one = ele.index(1)
# except ValueError:
# index_of_one = "NA"
# # Retrieve the name of the label (if no prediction made = NA)
# if index_of_one != "NA":
# name = label_dict[index_of_one]
# else:
# name = "Other"
# # Append name to list
# predictions_names.append(name)
# return predictions_names
# @st.cache_resource
# def load_targetClassifier(config_file:str = None, classifier_name:str = None):
# """
# Loads the document classifier as a transformers text-classification pipeline,
# where the name/path of the model on the HF-hub is given as a string and used
# to fetch the model object. Either config_file or classifier_name should be passed.
# 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
# 2. https://docs.haystack.deepset.ai/docs/document_classifier
# Params
# --------
# config_file: config file path from which to read the model name
# classifier_name: model name; takes priority if passed. Otherwise the name is \
# read from config_file; if neither is given, a warning is logged and None is returned.
# Return: document classifier model
# """
# if not classifier_name:
# if not config_file:
# logging.warning("Pass either model name or config file")
# return
# else:
# config = getconfig(config_file)
# classifier_name = config.get('target','MODEL')
# logging.info("Loading classifier")
# doc_classifier = pipeline("text-classification",
# model=classifier_name,
# top_k =1)
# return doc_classifier
# @st.cache_data
# def target_classification(haystack_doc:pd.DataFrame,
# threshold:float = 0.5,
# classifier_model:pipeline= None
# )->Tuple[DataFrame,Series]:
# """
# Text-Classification on the list of texts provided. The classifier provides the
# most appropriate label for each text. These labels indicate whether the paragraph
# references a specific action, target or measure.
# ---------
# haystack_doc: paragraphs to classify, annotated as a pandas DataFrame and read
# through its `text` column. The output of the Preprocessing Pipeline
# contains the list of paragraphs in different formats; here the tabular form is used.
# threshold: threshold value for the model to keep the results from classifier
# classifier_model: you can pass the classifier model directly, which takes priority;
# if it is not passed, the model is looked up in the streamlit session state.
# In case of streamlit avoid passing the model directly.
# Returns
# ----------
# haystack_doc: the input DataFrame with two added columns: 'Vulnerability Label'
# (initialised to 'NA') and 'Target Label' (the labels derived from the
# classifier predictions, one per paragraph).
# """
# logging.info("Working on target/action identification")
# haystack_doc['Vulnerability Label'] = 'NA'
# if not classifier_model:
# classifier_model = st.session_state['target_classifier']
# # Get predictions
# predictions = classifier_model(list(haystack_doc.text))
# # Get labels for predictions
# pred_labels = getlabels(predictions)
# # Save labels
# haystack_doc['Target Label'] = pred_labels
# # logging.info("Working on action/target extraction")
# # if not classifier_model:
# # classifier_model = st.session_state['target_classifier']
# # results = classifier_model(list(haystack_doc.text))
# # labels_= [(l[0]['label'],
# # l[0]['score']) for l in results]
# # df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
# # df = pd.concat([haystack_doc,df1],axis=1)
# # df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
# # df['Target Score'] = df['Target Score'].round(2)
# # df.index += 1
# # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
# return haystack_doc