vulnerability_2_1 / utils /conditional_classifier.py
leavoigt's picture
Upload 11 files
579b090
raw
history blame
3.97 kB
from typing import List, Tuple
from typing_extensions import Literal
import logging
import pandas as pd
from pandas import DataFrame, Series
from utils.config import getconfig
from utils.preprocessing import processingpipeline
import streamlit as st
from transformers import pipeline
@st.cache_resource
def load_conditionalClassifier(config_file:str = None, classifier_name:str = None):
    """
    Load the conditional text-classification model as a Hugging Face
    ``transformers`` pipeline.

    The model is identified by its name/path on the HF hub. Either
    ``classifier_name`` or ``config_file`` should be passed; when both are
    given, ``classifier_name`` takes priority and the config file is not read.

    The function is wrapped in ``st.cache_resource`` so that repeated calls
    with the same arguments reuse the already-loaded pipeline.

    Params
    --------
    config_file: path of the config file; the model name is read from the
        ``MODEL`` key of its ``conditional`` section.
    classifier_name: HF-hub name/path of the model; takes priority over
        ``config_file``.

    Return: a ``text-classification`` pipeline created with ``top_k=1``, or
        ``None`` (after logging a warning) when neither argument is given.
        Note: no exception is raised in that case.
    """
    if not classifier_name:
        if not config_file:
            # No source for the model name: warn and return None (callers
            # must handle the missing model) instead of raising.
            logging.warning("Pass either model name or config file")
            return None
        config = getconfig(config_file)
        classifier_name = config.get('conditional', 'MODEL')

    logging.info("Loading conditional classifier")
    # top_k=1: keep only the highest-scoring label for each input text.
    doc_classifier = pipeline("text-classification",
                              model=classifier_name,
                              top_k=1)
    return doc_classifier
@st.cache_data
def conditional_classification(haystack_doc:pd.DataFrame,
                               threshold:float = 0.8,
                               classifier_model:pipeline= None
                               )->DataFrame:
    """
    Run conditionality classification on the paragraphs that need it.

    A paragraph is sent to the classifier when its 'Target Label' equals
    'TARGET' or its 'Policy-Action Label' is non-empty. All other paragraphs
    keep the defaults 'Conditional Label' = 'NA' and
    'Conditional Score' = 0.0.

    Params
    ---------
    haystack_doc: DataFrame of paragraphs. It must contain the columns
        'text', 'Target Label' and 'Policy-Action Label' (the latter holding
        sized values, e.g. lists). The argument itself is not modified.
    threshold: currently unused; kept for interface compatibility.
    classifier_model: text-classification pipeline to use. If not given,
        st.session_state['conditional_classifier'] is used (looked up only
        when at least one paragraph needs classifying). In Streamlit, avoid
        passing the model directly: st.cache_data tries to hash every
        argument, and a pipeline object may not be hashable.

    Returns
    ----------
    df: DataFrame with 'Conditional Label' and 'Conditional Score' columns.
        Rows that were not classified come first, followed by the classified
        rows; the index is renumbered starting at 1.
    """
    logging.info("Working on Conditionality Identification")
    # Work on a copy: assigning columns on the argument would leak the helper
    # columns below into the caller's DataFrame (and only on cache misses).
    haystack_doc = haystack_doc.copy()
    haystack_doc['Conditional Label'] = 'NA'
    haystack_doc['Conditional Score'] = 0.0
    haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(
        lambda labels: len(labels) != 0)
    # Vectorised row test; unlike apply(axis=1) it also works on an empty frame
    # and always yields a boolean Series.
    haystack_doc['cond_check'] = (
        (haystack_doc['Target Label'] == 'TARGET')
        | (haystack_doc['PA_check'] == True)  # noqa: E712 - elementwise compare
    )

    # Only paragraphs that are target-related or carry a policy/action label
    # are classified.
    mask = haystack_doc['cond_check']
    temp = haystack_doc[mask].reset_index(drop=True)
    df = haystack_doc[~mask].reset_index(drop=True)

    # Skip the model entirely when nothing qualifies; otherwise unpacking the
    # (empty) results into two columns would raise ValueError.
    if not temp.empty:
        if not classifier_model:
            classifier_model = st.session_state['conditional_classifier']
        results = classifier_model(list(temp.text))
        # Each result is a list of {'label', 'score'} dicts; keep the first
        # (top-ranked) entry.
        temp['Conditional Label'] = [result[0]['label'] for result in results]
        temp['Conditional Score'] = [result[0]['score'] for result in results]

    # Merge back: unclassified rows first, then the classified ones.
    df = pd.concat([df, temp])
    df = df.drop(columns=['cond_check', 'PA_check'])
    df = df.reset_index(drop=True)
    df.index += 1
    return df