Shredder commited on
Commit
99cd595
1 Parent(s): 84442d0

Upload sus_fls.py

Browse files
Files changed (1) hide show
  1. sus_fls.py +52 -0
sus_fls.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaTokenizer,pipeline
2
+ import torch
3
+ import nltk
4
+ from nltk.tokenize import sent_tokenize
5
+ from fin_readability_sustainability import BERTClass, do_predict
6
+ import pandas as pd
7
+ import en_core_web_sm
8
+
9
+ nltk.download('punkt')
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ #SUSTAINABILITY STARTS
13
+ tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
14
+ model_sustain = BERTClass(2, "sustanability")
15
+ model_sustain.to(device)
16
+ model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])
17
+
18
+ def get_sustainability(text):
19
+ df = pd.DataFrame({'sentence':sent_tokenize(text)})
20
+ actual_predictions_sustainability = do_predict(model_sustain, tokenizer_sus, df)
21
+ highlight = []
22
+ for sent, prob in zip(df['sentence'].values, actual_predictions_sustainability[1]):
23
+ if prob>=4.384316:
24
+ highlight.append((sent, 'non-sustainable'))
25
+ elif prob<=1.423736:
26
+ highlight.append((sent, 'sustainable'))
27
+ else:
28
+ highlight.append((sent, '-'))
29
+ return highlight
30
+ #SUSTAINABILITY ENDS
31
+
32
+
33
+ ##Forward Looking Statement
34
+ nlp = en_core_web_sm.load()
35
+ def split_in_sentences(text):
36
+ doc = nlp(text)
37
+ return [str(sent).strip() for sent in doc.sents]
38
+ def make_spans(text,results):
39
+ results_list = []
40
+ for i in range(len(results)):
41
+ results_list.append(results[i]['label'])
42
+ facts_spans = []
43
+ facts_spans = list(zip(split_in_sentences(text),results_list))
44
+ return facts_spans
45
+
46
+ fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
47
+ def fls(text):
48
+ results = fls_model(split_in_sentences(text))
49
+ return make_spans(text,results)
50
+
51
+
52
+