TrungTech's picture
Upload 6 files
f0d3a66
import pickle
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import BertTokenizer, BertForSequenceClassification, pipeline, AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForSeq2SeqLM, AutoModel, RobertaModel, RobertaTokenizer
from sentence_transformers import SentenceTransformer
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
#import lightgbm
#lr_clf_finbert = pickle.load(open("lr_clf_finread_new.pkl",'rb'))
# --- Readability scorer setup -------------------------------------------
# Run on the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer that is handed to do_predict together with the model below.
tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')

# Project-defined classifier; the meaning of the arguments (2, "readability")
# is defined in fin_readability_sustainability -- presumably the number of
# output classes and a task name; confirm there.
model_read = BERTClass(2, "readability")
model_read.to(device)

# The checkpoint file is a dict; the weights live under 'model_state_dict'.
model_read.load_state_dict(
    torch.load('readability_model.bin', map_location=device)['model_state_dict']
)
def get_readability(text):
    """Score one piece of text with the module-level readability model.

    Args:
        text: The sentence or passage to score.

    Returns:
        ``do_predict(...)[1][0]`` rounded to four decimal places. Presumably
        this is the model's readability score for ``text`` -- confirm against
        ``do_predict`` in ``fin_readability_sustainability``.
    """
    # do_predict receives a DataFrame, so wrap the text in a one-row frame
    # with a single 'sentence' column.
    single_row = pd.DataFrame({'sentence': [text]})
    predictions = do_predict(model_read, tokenizer_read, single_row)
    return round(predictions[1][0], 4)
# --- Paraphraser setup ----------------------------------------------------
# Reference: https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base
# The tokenizer and the seq2seq model are loaded from the same checkpoint id.
_PARAPHRASER_CHECKPOINT = "humarin/chatgpt_paraphraser_on_T5_base"
tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASER_CHECKPOINT)
def paraphrase(
    question,
    num_beams=5,
    num_beam_groups=5,
    num_return_sequences=5,
    repetition_penalty=10.0,
    diversity_penalty=3.0,
    no_repeat_ngram_size=2,
    temperature=0.7,
    max_length=128
):
    """Generate paraphrases of ``question`` with the module-level seq2seq model.

    Args:
        question: Text to paraphrase; it is prefixed with ``'paraphrase: '``
            before tokenization.
        num_beams: Forwarded to ``model.generate``.
        num_beam_groups: Forwarded to ``model.generate``.
        num_return_sequences: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        diversity_penalty: Forwarded to ``model.generate``.
        no_repeat_ngram_size: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        max_length: Used twice: as the truncation length for the tokenized
            prompt and as ``max_length`` for generation.

    Returns:
        The generated sequences decoded to text with special tokens removed.
    """
    prompt = f'paraphrase: {question}'
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
        truncation=True,
    )

    # Collect the generation options in one place, then hand them over.
    generation_options = {
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "max_length": max_length,
        "diversity_penalty": diversity_penalty,
    }
    generated = model.generate(encoded.input_ids, **generation_options)

    return tokenizer.batch_decode(generated, skip_special_tokens=True)
def get_most_readable_paraphrse(text, paraphraser=None, scorer=None, threshold=0.6):
    """Return the most readable rewrite of ``text``, or an apology message.

    Candidate rewrites come from ``paraphraser``; the original ``text`` is
    appended as the last candidate. Every candidate is scored once with
    ``scorer`` and the highest score wins. On a tie the earliest candidate
    wins, so a paraphrase beats the original when their scores are equal.

    Args:
        text: The text to simplify.
        paraphraser: Optional callable mapping ``text`` to an iterable of
            candidate strings. Defaults to the module-level ``paraphrase``.
        scorer: Optional callable mapping a string to a comparable score.
            Defaults to the module-level ``get_readability``.
        threshold: The winning score must be strictly greater than this for
            a paraphrase to be suggested. Defaults to 0.6.

    Returns:
        A message containing the best paraphrase when it differs from
        ``text`` and scores above ``threshold``; otherwise a fixed message
        saying the input is already the most readable version.
    """
    # Resolve the defaults lazily so the module-level helpers are only looked
    # up when the caller did not inject replacements.
    generate = paraphrase if paraphraser is None else paraphraser
    score_of = get_readability if scorer is None else scorer

    # Build a fresh list instead of appending to the paraphraser's result.
    candidates = [*generate(text), text]
    scored = [(score_of(candidate), candidate) for candidate in candidates]

    # max() returns the first maximal item, matching a strict '>' scan.
    best_score, best = max(scored, key=lambda pair: pair[0])

    if best != text and best_score > threshold:
        return "The most readable version of text that I can think of is:\n" + best

    # The fallback message used to be a bare string expression, which left the
    # result variable unassigned and raised UnboundLocalError on this path.
    return (
        "Sorry! I am not confident. As per my best knowledge, "
        "you already have the most readable version of the text!"
    )
def set_example_text(example_text):
    """Build the textbox update for a clicked example.

    Args:
        example_text: Indexable sample; only its first element is used.

    Returns:
        The result of ``gr.Textbox.update`` with ``value`` set to that first
        element.
    """
    chosen_text = example_text[0]
    return gr.Textbox.update(value=chosen_text)
# --- Gradio UI --------------------------------------------------------------
# Components are declared top to bottom: heading, input box, button, output
# box, then a clickable set of example inputs.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# FinLanSer
Financial Language Simplifier
""")

    text = gr.Textbox(label="Enter text you want to simply (make more readable)")
    greet_btn = gr.Button("Simplify/Make Readable")
    output = gr.Textbox(label="Output Box")

    # Clicking the button sends the input text through the paraphrase
    # selection and writes the returned message into the output box.
    greet_btn.click(
        fn=get_most_readable_paraphrse,
        inputs=text,
        outputs=output,
        api_name="get_most_raedable_paraphrse",
    )

    # Each sample is a one-element list matching the single `text` component.
    example_text = gr.Dataset(
        components=[text],
        samples=[
            ['Legally assured line of credit with a bank'],
            ['A mutual fund is a type of financial vehicle made up of a pool of money collected from many investors to invest in securities like stocks, bonds, money market instruments'],
        ],
    )

    # Clicking a sample routes it through set_example_text into the
    # dataset's component(s), i.e. the input textbox.
    example_text.click(
        fn=set_example_text,
        inputs=example_text,
        outputs=example_text.components,
    )

demo.launch()