Shredder committed
Commit 10f176f
1 parent: a35d8de

Upload app.py

Files changed (1): app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
from predict import run_prediction
from io import StringIO
import json
import gradio as gr
import spacy
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification, RobertaTokenizer, pipeline
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
import en_core_web_sm
from fincat_utils import extract_context_words, bert_embedding_extract
import pickle

# Logistic-regression claim classifier used by score_fincat below
with open("lr_clf_FiNCAT.pickle", 'rb') as f:
    lr_clf = pickle.load(f)

nlp = en_core_web_sm.load()
nltk.download('punkt')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SUSTAINABILITY STARTS
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])

def get_sustainability(text):
    # Classify each sentence by its sustainability score against fixed thresholds.
    df = pd.DataFrame({'sentence': sent_tokenize(text)})
    actual_predictions_sustainability = do_predict(model_sustain, tokenizer_sus, df)
    highlight = []
    for sent, prob in zip(df['sentence'].values, actual_predictions_sustainability[1]):
        if prob >= 4.384316:
            highlight.append((sent, 'non-sustainable'))
        elif prob <= 1.423736:
            highlight.append((sent, 'sustainable'))
        else:
            highlight.append((sent, '-'))
    return highlight
# SUSTAINABILITY ENDS
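
# Illustrative usage (hypothetical output; actual labels depend on the trained
# sustainability model and the probability thresholds above):
#   get_sustainability("We pledge to cut emissions by 40%. Revenue rose 5%.")
#   -> [("We pledge to cut emissions by 40%.", "sustainable"),
#       ("Revenue rose 5%.", "-")]
# The (sentence, label) pairs are the format gr.HighlightedText accepts.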

# CLAIM STARTS
def score_fincat(txt):
    highlight = []
    txt = " " + txt + " "
    k = ''
    for word in txt.split():
        if any(char.isdigit() for char in word):
            # Strip one trailing punctuation mark so only the numeral is classified.
            if word[-1] in ['.', ',', ';', ":", "-", "!", "?", ")", '"', "'"]:
                k = word[-1]
                word = word[:-1]
            st = txt.find(" " + word + k + " ") + 1
            k = ''
            ed = st + len(word)
            x = {'paragraph': txt, 'offset_start': st, 'offset_end': ed}
            context_text = extract_context_words(x)
            features = bert_embedding_extract(context_text, word)
            if features[0] == 'None':
                highlight.append(('None', ' '))
                return highlight
            prediction = lr_clf.predict(features.reshape(1, 768))
            highlight.append((word, ' In-claim' if prediction[0] == 1 else 'Out-of-Claim'))
        else:
            highlight.append((word, ' '))
    return highlight
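
# Illustrative usage (hypothetical output; the FiNCAT classifier decides the
# labels). Every whitespace token gets a tuple, and only tokens containing a
# digit are actually classified:
#   score_fincat("Revenue grew 5% to 100 million.")
#   -> [('Revenue', ' '), ('grew', ' '), ('5%', ' In-claim'),
#       ('to', ' '), ('100', 'Out-of-Claim'), ('million.', ' ')]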


## Summarization
summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")

def summarize_text(text):
    resp = summarizer(text)
    stext = resp[0]['summary_text']
    return stext
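
# Illustrative usage (summary wording is model-dependent):
#   summarize_text("The committee met to discuss Q3 results and agreed ...")
#   -> "The committee discussed Q3 results."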


def split_in_sentences(text):
    doc = nlp(text)
    return [str(sent).strip() for sent in doc.sents]

def make_spans(text, results):
    results_list = [result['label'] for result in results]
    facts_spans = list(zip(split_in_sentences(text), results_list))
    return facts_spans
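
# Illustrative shape (labels are whatever the classifier returns):
#   make_spans("We expect growth. We met targets.",
#              [{'label': 'Specific FLS'}, {'label': 'Not FLS'}])
#   -> [('We expect growth.', 'Specific FLS'), ('We met targets.', 'Not FLS')]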

## Forward Looking Statement
fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")

def fls(text):
    results = fls_model(split_in_sentences(text))
    return make_spans(text, results)

## Company Extraction
ner = pipeline('ner', model='Jean-Baptiste/camembert-ner-with-dates', tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")

def fin_ner(text):
    replaced_spans = ner(text)
    new_spans = []
    for item in replaced_spans:
        # The aggregated pipeline returns 'entity_group'; rename it to 'entity',
        # which is the key the HighlightedText component reads.
        item['entity'] = item['entity_group']
        del item['entity_group']
        new_spans.append(item)
    return {"text": text, "entities": new_spans}
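
# Illustrative output (entity types come from the camembert-ner-with-dates
# model; fields other than 'entity' pass through unchanged):
#   fin_ner("Apple reported results on 2 May 2022.")
#   -> {"text": "...", "entities": [{"entity": "ORG", "word": "Apple", ...},
#                                   {"entity": "DATE", "word": "2 May 2022", ...}]}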


# CUAD STARTS
def load_questions():
    with open('questions.txt') as f:
        questions = f.readlines()
    return questions

def load_questions_short():
    with open('questionshort.txt') as f:
        questions_short = f.readlines()
    return questions_short

questions = load_questions()
questions_short = load_questions_short()
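
# Both files are read with readlines(), so each line (including its trailing
# newline) becomes one entry; questions.txt holds the full CUAD-style questions
# and questionshort.txt their short forms.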

def quad(query, file):
    with open(file.name) as f:
        paragraph = f.read()
    answer = ""
    if paragraph and query:
        print('getting predictions')
        predictions = run_prediction([query], paragraph, 'marshmellow77/roberta-base-cuad', n_best_size=5)
        if predictions['0'] == "":
            answer = 'No answer found in document'
        else:
            # run_prediction writes its n-best candidates to nbest.json;
            # report the top candidate with its probability.
            with open("nbest.json") as jf:
                data = json.load(jf)
            for i in range(1):
                answer += f"{data['0'][i]['text']} -- \n"
                answer += f"Probability: {round(data['0'][i]['probability']*100, 1)}%\n\n"
    return answer, summarize_text(answer), fin_ner(answer), score_fincat(answer), get_sustainability(answer), fls(answer)
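
# The six return values map one-to-one onto the six output components of the
# gr.Interface defined at the bottom of the file: Answer, Summary, NER, CLAIM,
# SUSTAINABILITY, and FLS.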


# Earlier single-tool sustainability interface, kept for reference:
# b6 = gr.Button("Get Sustainability")
# b6.click(get_sustainability, inputs=text, outputs=gr.HighlightedText())
# iface = gr.Interface(fn=get_sustainability, inputs="textbox", title="CONBERT",
#                      description="SUSTAINABILITY TOOL", outputs=gr.HighlightedText(),
#                      allow_flagging="never")
# iface.launch()

iface = gr.Interface(
    fn=quad,
    inputs=[gr.Dropdown(choices=questions, label='SEARCH QUERY'), gr.File(label='TXT FILE')],
    outputs=[
        gr.Textbox(label='Answer'),
        gr.Textbox(label='Summary'),
        gr.HighlightedText(label='NER'),
        gr.HighlightedText(label='CLAIM'),
        gr.HighlightedText(label='SUSTAINABILITY'),
        gr.HighlightedText(label='FLS'),
    ],
    title="CONBERT",
    description="SUSTAINABILITY TOOL",
    article='Article',
    allow_flagging="never",
)

iface.launch()