leavoigt commited on
Commit
2df11a9
1 Parent(s): 8a9ef0d

Delete utils/conditional_classifier.py

Browse files
Files changed (1) hide show
  1. utils/conditional_classifier.py +0 -95
utils/conditional_classifier.py DELETED
@@ -1,95 +0,0 @@
1
- from typing import List, Tuple
2
- from typing_extensions import Literal
3
- import logging
4
- import pandas as pd
5
- from pandas import DataFrame, Series
6
- from utils.config import getconfig
7
- from utils.preprocessing import processingpipeline
8
- import streamlit as st
9
- from transformers import pipeline
10
-
11
-
12
- @st.cache_resource
13
- def load_conditionalClassifier(config_file:str = None, classifier_name:str = None):
14
- """
15
- loads the document classifier using haystack, where the name/path of model
16
- in HF-hub as string is used to fetch the model object.Either configfile or
17
- model should be passed.
18
- 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
19
- 2. https://docs.haystack.deepset.ai/docs/document_classifier
20
- Params
21
- --------
22
- config_file: config file path from which to read the model name
23
- classifier_name: if modelname is passed, it takes a priority if not \
24
- found then will look for configfile, else raise error.
25
- Return: document classifier model
26
- """
27
- if not classifier_name:
28
- if not config_file:
29
- logging.warning("Pass either model name or config file")
30
- return
31
- else:
32
- config = getconfig(config_file)
33
- classifier_name = config.get('conditional','MODEL')
34
-
35
- logging.info("Loading conditional classifier")
36
- doc_classifier = pipeline("text-classification",
37
- model=classifier_name,
38
- top_k =1)
39
-
40
- return doc_classifier
41
-
42
-
43
- @st.cache_data
44
- def conditional_classification(haystack_doc:pd.DataFrame,
45
- threshold:float = 0.8,
46
- classifier_model:pipeline= None
47
- )->Tuple[DataFrame,Series]:
48
- """
49
- Text-Classification on the list of texts provided. Classifier provides the
50
- most appropriate label for each text. It informs if paragraph contains any
51
- netzero information or not.
52
- Params
53
- ---------
54
- haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
55
- contains the list of paragraphs in different format,here the list of
56
- Haystack Documents is used.
57
- threshold: threshold value for the model to keep the results from classifier
58
- classifiermodel: you can pass the classifier model directly,which takes priority
59
- however if not then looks for model in streamlit session.
60
- In case of streamlit avoid passing the model directly.
61
- Returns
62
- ----------
63
- df: Dataframe
64
- """
65
- logging.info("Working on Conditionality Identification")
66
- haystack_doc['Conditional Label'] = 'NA'
67
- haystack_doc['Conditional Score'] = 0.0
68
- haystack_doc['cond_check'] = False
69
- haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
70
-
71
- #df1 = haystack_doc[haystack_doc['PA_check'] == True]
72
- #df = haystack_doc[haystack_doc['PA_check'] == False]
73
- haystack_doc['cond_check'] = haystack_doc.apply(lambda x: True if (
74
- (x['Target Label'] == 'TARGET') | (x['PA_check'] == True)) else
75
- False, axis=1)
76
- # we apply Netzero to only paragraphs which are classified as 'Target' related
77
- temp = haystack_doc[haystack_doc['cond_check'] == True]
78
- temp = temp.reset_index(drop=True)
79
- df = haystack_doc[haystack_doc['cond_check'] == False]
80
- df = df.reset_index(drop=True)
81
-
82
- if not classifier_model:
83
- classifier_model = st.session_state['conditional_classifier']
84
-
85
- results = classifier_model(list(temp.text))
86
- labels_= [(l[0]['label'],l[0]['score']) for l in results]
87
- temp['Conditional Label'],temp['Conditional Score'] = zip(*labels_)
88
- # temp[' Label'] = temp['Netzero Label'].apply(lambda x: _lab_dict[x])
89
- # merging Target with Non Target dataframe
90
- df = pd.concat([df,temp])
91
- df = df.drop(columns = ['cond_check','PA_check'])
92
- df = df.reset_index(drop =True)
93
- df.index += 1
94
-
95
- return df