ppsingh committed on
Commit
3e0be2b
·
1 Parent(s): ae31548

Update utils/ghg_classifier.py

Browse files
Files changed (1) hide show
  1. utils/ghg_classifier.py +89 -90
utils/ghg_classifier.py CHANGED
@@ -1,90 +1,89 @@
1
- from haystack.nodes import TransformersDocumentClassifier
2
- from haystack.schema import Document
3
- from typing import List, Tuple
4
- from typing_extensions import Literal
5
- import logging
6
- import pandas as pd
7
- from pandas import DataFrame, Series
8
- from utils.config import getconfig
9
- from utils.preprocessing import processingpipeline
10
- import streamlit as st
11
- from transformers import pipeline
12
-
13
- # Labels dictionary ###
14
- _lab_dict = {
15
- 'NEGATIVE':'NO GHG TARGET',
16
- 'TARGET':'GHG TARGET',
17
- }
18
-
19
- @st.cache_resource
20
- def load_ghgClassifier(config_file:str = None, classifier_name:str = None):
21
- """
22
- loads the document classifier using haystack, where the name/path of model
23
- in HF-hub as string is used to fetch the model object.Either configfile or
24
- model should be passed.
25
- 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
26
- 2. https://docs.haystack.deepset.ai/docs/document_classifier
27
- Params
28
- --------
29
- config_file: config file path from which to read the model name
30
- classifier_name: if modelname is passed, it takes a priority if not \
31
- found then will look for configfile, else raise error.
32
- Return: document classifier model
33
- """
34
- if not classifier_name:
35
- if not config_file:
36
- logging.warning("Pass either model name or config file")
37
- return
38
- else:
39
- config = getconfig(config_file)
40
- classifier_name = config.get('ghg','MODEL')
41
-
42
- logging.info("Loading ghg classifier")
43
- doc_classifier = pipeline("text-classification",
44
- model=classifier_name,
45
- top_k =1)
46
-
47
- return doc_classifier
48
-
49
-
50
- @st.cache_data
51
- def ghg_classification(haystack_doc:pd.DataFrame,
52
- threshold:float = 0.5,
53
- classifier_model:pipeline= None
54
- )->Tuple[DataFrame,Series]:
55
- """
56
- Text-Classification on the list of texts provided. Classifier provides the
57
- most appropriate label for each text. these labels are in terms of if text
58
- belongs to which particular Sustainable Devleopment Goal (SDG).
59
- Params
60
- ---------
61
- haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
62
- contains the list of paragraphs in different format,here the list of
63
- Haystack Documents is used.
64
- threshold: threshold value for the model to keep the results from classifier
65
- classifiermodel: you can pass the classifier model directly,which takes priority
66
- however if not then looks for model in streamlit session.
67
- In case of streamlit avoid passing the model directly.
68
- Returns
69
- ----------
70
- df: Dataframe with two columns['SDG:int', 'text']
71
- x: Series object with the unique SDG covered in the document uploaded and
72
- the number of times it is covered/discussed/count_of_paragraphs.
73
- """
74
- logging.info("Working on GHG Extraction")
75
- haystack_doc['GHG Label'] = 'NA'
76
- haystack_doc['GHG Score'] = 'NA'
77
- temp = haystack_doc[haystack_doc['Target Label'] == 'TARGET']
78
- df = haystack_doc[haystack_doc['Target Label'] == 'NEGATIVE']
79
-
80
- if not classifier_model:
81
- classifier_model = st.session_state['ghg_classifier']
82
-
83
- results = classifier_model(list(temp.text))
84
- labels_= [(l[0]['label'],l[0]['score']) for l in results]
85
- temp['GHG Label'],temp['GHG Score'] = zip(*labels_)
86
- df = pd.concat([df,temp])
87
- df = df.reset_index(drop =True)
88
- df.index += 1
89
-
90
- return df
 
1
+ from haystack.schema import Document
2
+ from typing import List, Tuple
3
+ from typing_extensions import Literal
4
+ import logging
5
+ import pandas as pd
6
+ from pandas import DataFrame, Series
7
+ from utils.config import getconfig
8
+ from utils.preprocessing import processingpipeline
9
+ import streamlit as st
10
+ from transformers import pipeline
11
+
12
+ # Labels dictionary ###
13
+ _lab_dict = {
14
+ 'NEGATIVE':'NO GHG TARGET',
15
+ 'TARGET':'GHG TARGET',
16
+ }
17
+
18
@st.cache_resource
def load_ghgClassifier(config_file:str = None, classifier_name:str = None):
    """
    Build the GHG text-classification pipeline.

    The classifier is created with
    transformers.pipeline("text-classification", model=..., top_k=1).
    An explicitly passed classifier_name takes priority; the config file is
    consulted only when no name is given.

    Params
    --------
    config_file: path handed to getconfig(); the model name is then fetched
        with .get('ghg','MODEL') on the object getconfig returns.
    classifier_name: model name/path forwarded as `model=` to the pipeline.

    Return: the pipeline object, or None (after a logged warning) when
        neither classifier_name nor config_file is supplied.
    """
    model_id = classifier_name
    if not model_id:
        # Guard: without a name or a config file there is nothing to load.
        if not config_file:
            logging.warning("Pass either model name or config file")
            return
        model_id = getconfig(config_file).get('ghg','MODEL')

    logging.info("Loading ghg classifier")
    return pipeline("text-classification", model=model_id, top_k=1)
47
+
48
+
49
@st.cache_data
def ghg_classification(haystack_doc:pd.DataFrame,
                        threshold:float = 0.5,
                        classifier_model:pipeline= None
                        )->DataFrame:
    """
    Label the 'TARGET' paragraphs of a dataframe with the GHG classifier.

    Rows whose 'Target Label' is 'TARGET' are sent to the classifier and get
    its top label and score. Rows whose 'Target Label' is 'NEGATIVE' keep
    the placeholder 'NA' in both GHG columns. Rows with any other
    'Target Label' value are not part of the returned frame.

    Params
    ---------
    haystack_doc: DataFrame with at least the columns 'text' and
        'Target Label'. The columns 'GHG Label' and 'GHG Score' are
        added to (or overwritten on) this frame in place, set to 'NA'.
    threshold: currently unused; kept for interface compatibility.
    classifier_model: callable that takes a list of texts and returns, per
        text, a sequence whose first item is a dict with the keys 'label'
        and 'score'. If None, the model stored under
        st.session_state['ghg_classifier'] is used.
        In case of streamlit avoid passing the model directly.

    Returns
    ----------
    df: DataFrame holding the 'NEGATIVE' rows followed by the classified
        'TARGET' rows, re-indexed to start at 1.

    Raises
    ----------
    KeyError: if there are 'TARGET' rows, no model is passed and
        'ghg_classifier' is missing from the streamlit session state.
    """
    logging.info("Working on GHG Extraction")
    # Placeholders are written on the caller's frame, as before.
    haystack_doc['GHG Label'] = 'NA'
    haystack_doc['GHG Score'] = 'NA'

    # .copy() makes the target rows an independent frame, so the column
    # assignments below do not go through a slice of haystack_doc.
    targets = haystack_doc[haystack_doc['Target Label'] == 'TARGET'].copy()
    negatives = haystack_doc[haystack_doc['Target Label'] == 'NEGATIVE']

    # Without target rows there is nothing to classify; skipping avoids the
    # unpacking error an empty result list used to trigger.
    if not targets.empty:
        if classifier_model is None:
            classifier_model = st.session_state['ghg_classifier']

        predictions = classifier_model(targets['text'].tolist())
        # Each prediction is a ranked list; only the top entry is kept.
        top_hits = [ranked[0] for ranked in predictions]
        targets['GHG Label'] = [hit['label'] for hit in top_hits]
        targets['GHG Score'] = [hit['score'] for hit in top_hits]

    df = pd.concat([negatives, targets]).reset_index(drop=True)
    df.index += 1  # 1-based index in the returned frame

    return df