Spaces:
Runtime error
Runtime error
import nltk | |
nltk.download('stopwords') | |
nltk.download('punkt') | |
import pickle | |
from keybert import KeyBERT | |
from nltk.util import everygrams | |
from nltk.corpus import stopwords | |
from nltk.tokenize import sent_tokenize | |
from fincat_utils import extract_context_words | |
from fincat_utils import bert_embedding_extract | |
from sentence_transformers import SentenceTransformer, util | |
import torch | |
from transformers import BertTokenizer, BertForSequenceClassification, pipeline, AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForSeq2SeqLM, AutoModel, RobertaModel, RobertaTokenizer | |
import gradio as gr | |
import pandas as pd | |
from fin_readability_sustainability import BERTClass, do_predict | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# KeyBERT keyword extractor backed by the 'all-mpnet-base-v2' sentence encoder.
kw_model = KeyBERT(model='all-mpnet-base-v2')

# ESG: FinBERT-ESG sequence classifier (4 labels) wrapped in a HF pipeline.
finbert_esg = BertForSequenceClassification.from_pretrained(
    'yiyanghkust/finbert-esg', num_labels=4
)
tokenizer_esg = BertTokenizer.from_pretrained('yiyanghkust/finbert-esg')
nlp_esg = pipeline(task="text-classification", model=finbert_esg, tokenizer=tokenizer_esg)

# FLS: FinBERT-FLS sequence classifier (3 labels) wrapped in a HF pipeline.
finbert_fls = BertForSequenceClassification.from_pretrained(
    'yiyanghkust/finbert-fls', num_labels=3
)
tokenizer_fls = BertTokenizer.from_pretrained('yiyanghkust/finbert-fls')
nlp_fls = pipeline(task="text-classification", model=finbert_fls, tokenizer=tokenizer_fls)
# FinCAT - Claim Detection
# Classifier used by score_fincat to tag a numeral's context embedding as
# in-claim / out-of-claim (presumably logistic regression, per the "lr_" name).
# The `with` block closes the file handle, which the bare open() leaked.
# NOTE(review): pickle.load can execute arbitrary code -- only ship a trusted file.
with open("lr_clf_FiNCAT.pickle", 'rb') as _claim_clf_file:
    lr_clf_claim = pickle.load(_claim_clf_file)
# Sustainability & Readability models: each is a project BERTClass whose
# fine-tuned weights live under 'model_state_dict' in a local checkpoint file.
def _load_finetuned(model, checkpoint_path):
    """Move *model* to `device`, load weights from *checkpoint_path*, and return it."""
    model.to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


# Sustainability (RoBERTa tokenizer).
# NOTE(review): "sustanability" is spelled this way in the original call;
# confirm what BERTClass expects before correcting it.
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
model_sustain = _load_finetuned(BERTClass(2, "sustanability"), 'sustainability_model.bin')

# Readability (FinBERT tokenizer).
tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
model_read = _load_finetuned(BERTClass(2, "readability"), 'readability_model.bin')
# Sentiment: two independent classifiers; getfinsenti keeps, per sentence,
# whichever prediction has the higher score.
model_senti1 = BertForSequenceClassification.from_pretrained(
    'yiyanghkust/finbert-tone', num_labels=3
)
tokenizer_senti1 = BertTokenizer.from_pretrained('yiyanghkust/finbert-tone')
senti1 = pipeline(task="sentiment-analysis", model=model_senti1, tokenizer=tokenizer_senti1)

_SENTI2_REPO = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
model_senti2 = AutoModelForSequenceClassification.from_pretrained(_SENTI2_REPO)
tokenizer_senti2 = AutoTokenizer.from_pretrained(_SENTI2_REPO)
senti2 = TextClassificationPipeline(model=model_senti2, tokenizer=tokenizer_senti2)

# Summarization: seq2seq (PEGASUS) checkpoint used by summarize_pega.
_FINSUM_REPO = "human-centered-summarization/financial-summarization-pegasus"
model_finsum = AutoModelForSeq2SeqLM.from_pretrained(_FINSUM_REPO)
tokenizer_finsum = AutoTokenizer.from_pretrained(_FINSUM_REPO)
# Hypernym Detection
# Sentence-embedding model used to match user terms against the fixed list of
# candidate hypernyms below; the candidates are embedded once at start-up.
model_finlipi = SentenceTransformer('sohomghosh/LIPI_FinSim3_Hypernym')
hypernyms = ['Bonds', 'Forward', 'Funds', 'Future', 'MMIs', 'Option', 'Stocks', 'Swap', 'Equity Index', 'Credit Index', 'Securities restrictions', 'Parametric schedules', 'Debt pricing and yields', 'Credit Events', 'Stock Corporation', 'Central Securities Depository', 'Regulatory Agency']
# Row index of hypernyms_embeddings -> hypernym label.
hyp_di = dict(enumerate(hypernyms))
hypernyms_embeddings = model_finlipi.encode(hypernyms)
# ESG
def esg(text):
    """Classify every sentence of *text* with the FinBERT-ESG pipeline.

    Returns:
        A list of (sentence, label) tuples, one per sentence produced by
        nltk's sent_tokenize, in the shape gr.HighlightedText expects.
    """
    sentences = sent_tokenize(text)
    labels = [prediction['label'] for prediction in nlp_esg(sentences)]
    return list(zip(sentences, labels))
# FLS
def fls(text):
    """Classify every sentence of *text* with the FinBERT-FLS pipeline.

    Returns:
        A list of (sentence, label) tuples, one per sentence produced by
        nltk's sent_tokenize, in the shape gr.HighlightedText expects.
    """
    sentences = sent_tokenize(text)
    labels = [prediction['label'] for prediction in nlp_fls(sentences)]
    return list(zip(sentences, labels))
# Sentiment
def getfinsenti(text):
    """Tag each sentence of *text* with the more confident of two sentiment models.

    Both `senti1` and `senti2` score every sentence. The prediction with the
    higher score wins, and ties go to `senti1`. The winning label is stripped
    and lower-cased so both models' labels line up with one colour map.

    Returns:
        A list of (sentence, label) tuples; empty when *text* has no sentences.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        return []
    # One batched call per model instead of two pipeline calls per sentence;
    # a list input yields one top-1 {'label', 'score'} dict per sentence.
    preds1 = senti1(sentences)
    preds2 = senti2(sentences)
    highlight = []
    for sentence, pred1, pred2 in zip(sentences, preds1, preds2):
        best = pred1 if pred1['score'] >= pred2['score'] else pred2
        highlight.append((sentence, best['label'].strip().lower()))
    return highlight
# Summarization
def summarize_pega(text):
    """Return a short abstractive summary of *text*.

    Decoding uses beam search (5 beams, early stopping) and is capped at
    32 tokens.
    """
    # truncation=True clips the input to the tokenizer's model_max_length.
    # Longer inputs would otherwise be passed to generate() untruncated and
    # can exceed the model's position limit.
    input_ids = tokenizer_finsum(text, return_tensors="pt", truncation=True).input_ids
    output = model_finsum.generate(
        input_ids,
        max_length=32,
        num_beams=5,
        early_stopping=True,
    )
    # decode() already returns a str.
    return tokenizer_finsum.decode(output[0], skip_special_tokens=True)
# Hypernym Detection
def get_hyp(words, th=0.85):
    """Map each comma-separated term in *words* to its closest known hypernym.

    Args:
        words: Comma-separated terms, e.g. "bond, call option". Blank entries
            are ignored, and None is treated as "".
        th: Minimum cosine similarity needed to report a hypernym. None falls
            back to 0.85.

    Returns:
        A list of (term, hypernym) tuples in input order. A term whose best
        similarity is not above *th* is paired with 'no hypernym found'.
    """
    if th is None:
        th = 0.85
    # Drop blank entries (from "", "a,,b", or a trailing comma) before
    # encoding, so the model is never run on strings that would be discarded.
    queries = [wd.strip() for wd in (words or "").split(",") if wd.strip()]
    if not queries:
        return []
    query_embeddings = model_finlipi.encode(queries)
    cos_scores = util.pytorch_cos_sim(query_embeddings, hypernyms_embeddings)
    # Best-matching hypernym (value and index) for every query row.
    best = torch.max(cos_scores, dim=1)
    return [
        (query, hyp_di[idx.item()] if sim.item() > th else 'no hypernym found')
        for sim, idx, query in zip(best.values, best.indices, queries)
    ]
# FinCAT - Claim Detection
def score_fincat(txt):
    '''
    Extracts numerals from financial texts and checks if they are in-claim or out-of-claim.

    Parameters:
        txt (str): Financial text. Every whitespace-separated token that contains
            a digit is classified.
    Returns:
        highlight (list): One (token, tag) tuple per token, in order. A token with
            a digit has a single trailing punctuation mark stripped and is tagged
            'In-claim' or 'Out-of-Claim'. Every other token is returned unchanged
            with tag ''.
    '''
    if not txt:
        return []
    trailing_punct = {'.', ',', ';', ':', '-', '!', '?', ')', '"', "'"}
    # Offsets are computed on (and the paragraph passed on is) the text padded
    # with one space at each end.
    txt = " " + txt + " "
    highlight = []
    cursor = 0
    for token in txt.split():
        # Only whitespace separates the previous token from this one, so
        # searching from the end of the previous token finds THIS occurrence.
        # A global search with " " delimiters picks the first duplicate and
        # fails outright when tokens are separated by newlines or tabs.
        start = txt.index(token, cursor)
        cursor = start + len(token)
        if not any(ch.isdigit() for ch in token):
            highlight.append((token, ''))
            continue
        word = token[:-1] if token[-1] in trailing_punct else token
        x = {'paragraph': txt, 'offset_start': start, 'offset_end': start + len(word)}
        context_text = extract_context_words(x)
        features = bert_embedding_extract(context_text, word)
        prediction = lr_clf_claim.predict(features.reshape(1, 768))
        # predict() returns one label per row; there is exactly one row here.
        highlight.append((word, 'In-claim' if prediction[0] == 1 else 'Out-of-Claim'))
    return highlight
# Readability
def get_readability(text):
    """Tag each sentence of *text* as 'readable' or 'non-readable'.

    The first value returned by do_predict is read as one class prediction per
    sentence; a prediction equal to 1 means 'readable'.

    Returns:
        A list of (sentence, tag) tuples; empty when *text* has no sentences.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        # Nothing to score. Skip the model call, because do_predict's handling
        # of an empty frame is not visible here.
        return []
    df = pd.DataFrame({'sentence': sentences})
    predictions = do_predict(model_read, tokenizer_read, df)[0]
    return [
        (sent, 'readable' if pred == 1 else 'non-readable')
        for sent, pred in zip(sentences, predictions)
    ]
# Sustainability
def get_sustainability(text, non_sustainable_min=2.8, sustainable_max=1.423736):
    """Tag each sentence of *text* as 'non-sustainable', 'sustainable' or '-'.

    The second value returned by do_predict is used as a per-sentence score.
    NOTE(review): the thresholds exceed 1, so this is presumably not a
    probability; confirm against do_predict.

    Args:
        text: Raw input text, split into sentences with nltk's sent_tokenize.
        non_sustainable_min: Scores >= this are tagged 'non-sustainable'. An
            earlier, commented-out value was 4.384316.
        sustainable_max: Scores <= this (and below non_sustainable_min) are
            tagged 'sustainable'.

    Returns:
        A list of (sentence, tag) tuples. Scores strictly between the two
        thresholds are tagged '-'. The list is empty when *text* has no
        sentences.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        # Nothing to score; skip the model call on an empty frame.
        return []
    df = pd.DataFrame({'sentence': sentences})
    scores = do_predict(model_sustain, tokenizer_sus, df)[1]
    highlight = []
    for sent, score in zip(sentences, scores):
        if score >= non_sustainable_min:
            tag = 'non-sustainable'
        elif score <= sustainable_max:
            tag = 'sustainable'
        else:
            tag = '-'
        highlight.append((sent, tag))
    return highlight
# Keywords
def get_keywords(text):
    """Return KeyBERT's top-3 keyphrases (1-3 words, English stop words removed).

    The phrases are comma-joined with no spaces, ready to paste into the
    hypernym-detection box.
    """
    scored = kw_model.extract_keywords(
        text,
        keyphrase_ngram_range=(1, 3),
        stop_words='english',
        highlight=False,
        top_n=3,
    )
    # Keep first-seen order while dropping any duplicate phrases.
    unique_phrases = dict.fromkeys(phrase for phrase, _score in scored)
    return ",".join(phrase.strip() for phrase in unique_phrases)
# Examples
def set_example_text(example_text):
    """Copy a clicked example row into the main text box.

    Args:
        example_text: The clicked gr.Dataset sample, a list holding one value
            per Dataset component (here just the example text).

    Returns:
        The example string. Returning the plain value updates the output
        Textbox in every Gradio version; ``gr.Textbox.update`` was removed in
        Gradio 4.
    """
    return example_text[0]
# UI layout. Top-level gr.Textbox / gr.Slider replace the gr.inputs.* namespace
# (deprecated in Gradio 3.x, removed in 4.0). color_map is passed to the
# HighlightedText constructor because Component.style() was removed in 4.0.
with gr.Blocks() as demo:
    gr.Markdown("# **Financial Language Understandability Enhancement Toolkit (FLUEnT)**")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Enter financial text here", lines=6, placeholder="Enter Financial Text here...")
            # Start at get_hyp's own default threshold instead of leaving the
            # slider's initial value unset.
            b_hyp_th = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.85, label="Detect hypernyms with confidence of")
            with gr.Row():
                b1 = gr.Button("Get Keywords For Hypernym Detection")
            with gr.Row():
                jargons = gr.Textbox(label="Enter words for Hypernyms Detection separated by comma")
            b1.click(get_keywords, inputs=text, outputs=jargons)
            example_text = gr.Dataset(
                components=[text],
                samples=[
                    ["Markets are falling."],
                    ["Exchanges the coupon on a bond for LIBOR plus a spread."],
                    ["We follow a carbon neutrality strategy, seek to use resources efficiently and work to deliver sustainable value for society"],
                    ["NGOs have been instrumental in shaping the economy"],
                    ["We expect to boost our sales by 80% this year by using eco-friendly products."],
                    ["We will continue to evaluate the need for an employee allowance as it hinders growth."],
                    ["As an example, in the calculation as of the end of 2020, carbon emissions of an issuer relate to 2019, whereas market capitalization is shown as of the end of 2020."],
                    ["In addition to the impacts from the merger, insurance income increased $121 million due to strong production and acquisitions."],
                    ["In the year 2021, the markets were bullish. We expect to boost our sales by 80% this year by using eco-friendly products."],
                    ["Noninterest income increased $1.7 billion due primarily to the Merger and higher residential mortgage income as a result of the lower rate environment driving mortgage production through refinance activity, partially offset by lower residential mortgage servicing income driven by higher prepayment and an MSR fair value adjustment in 2020. This year it will increase to $3M."],
                ],
            )
            example_text.click(fn=set_example_text, inputs=example_text, outputs=example_text.components)
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Hypernyms & Claims"):
                    with gr.Row():
                        b_hyp = gr.Button("Get Hypernyms")
                        b_hyp.click(get_hyp, inputs=[jargons, b_hyp_th], outputs=gr.HighlightedText())
                    with gr.Row():
                        b3 = gr.Button("Get Claims")
                        b3.click(score_fincat, inputs=text, outputs=gr.HighlightedText(color_map={"In-claim": "red", "Out-of-Claim": "green"}))
                with gr.TabItem("Summary & Sentiment"):
                    with gr.Row():
                        b2 = gr.Button("Get Summary")
                        b2.click(summarize_pega, inputs=text, outputs=gr.Textbox(label="Summary"))
                    with gr.Row():
                        b4 = gr.Button("Get Sentiment")
                        b4.click(getfinsenti, inputs=text, outputs=gr.HighlightedText(color_map={"negative": "red", "neutral": "blue", "positive": "green"}))
                with gr.TabItem("Readability & Sustainability"):
                    with gr.Row():
                        b5 = gr.Button("Get Readability")
                        b5.click(get_readability, inputs=text, outputs=gr.HighlightedText(color_map={"non-readable": "red", "readable": "green"}))
                    with gr.Row():
                        b6 = gr.Button("Get Sustainability")
                        b6.click(get_sustainability, inputs=text, outputs=gr.HighlightedText(color_map={"non-sustainable": "red", "-": "blue", "sustainable": "green"}))
                with gr.TabItem("ESG & FLS"):
                    with gr.Row():
                        b7 = gr.Button("Get Environmental, Social & Gov.(ESG)")
                        b7.click(esg, inputs=text, outputs=gr.HighlightedText(color_map={"Governance": "red", "Social": "blue", "Environmental": "green", "None": "yellow"}))
                    with gr.Row():
                        b8 = gr.Button("Get Forward Looking Statements(FLS)")
                        b8.click(fls, inputs=text, outputs=gr.HighlightedText(color_map={"Non-specific FLS": "red", "Not-FLS": "blue", "Specific-FLS": "green"}))
    gr.Markdown("How to use? [link](https://youtu.be/Bp8Ij5GQ59I), Warning: User discretion is advised., Colab Notebook [link](https://colab.research.google.com/drive/1-KBBKByCU2bkyAUDwW-h6QCSqWI8z127?usp=sharing), Citing Paper [link](https://easychair.org/publications/preprint/cWW5)")

demo.launch()