Shredder commited on
Commit
c2f789d
1 Parent(s): 211535a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer
3
+ import torch
4
+ from fin_readability_sustainability import BERTClass, do_predict
5
+
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+
10
+ tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
11
+ model_sustain = BERTClass(2, "sustanability")
12
+ model_sustain.to(device)
13
+ model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])
14
+
15
+
16
+ from nltk.tokenize import sent_tokenize
17
+ def get_sustainability(text):
18
+ df = pd.DataFrame({'sentence':sent_tokenize(text)})
19
+ actual_predictions_sustainability = do_predict(model_sustain, tokenizer_sus, df)
20
+ highlight = []
21
+ for sent, prob in zip(df['sentence'].values, actual_predictions_sustainability[1]):
22
+ if prob>=4.384316:
23
+ highlight.append((sent, 'non-sustainable'))
24
+ elif prob<=1.423736:
25
+ highlight.append((sent, 'sustainable'))
26
+ else:
27
+ highlight.append((sent, '-'))
28
+ return highlight
29
+
30
+
31
+
32
+ b6 = gr.Button("Get Sustainability")
33
+ b6.click(get_sustainability, inputs = text, outputs = gr.HighlightedText())
34
+
35
+
36
+ iface = gr.Interface(fn=get_sustainability, inputs=gr.inputs.Textbox(lines=5, placeholder="Enter Financial Text here..."), title="CONBERT",description="SUSTAINABILITY TOOL", outputs=gr.HighlightedText(), allow_flagging="never")
37
+ iface.launch()