sohomghosh committed on
Commit
13a47a2
1 Parent(s): 8902661

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ from sentence_transformers import SentenceTransformer
5
+ import lightgbm
6
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only ship a "lr_clf_finread_new.pkl" that comes from a trusted source.
# The context manager closes the file handle as soon as the classifier is loaded.
with open("lr_clf_finread_new.pkl", "rb") as clf_file:
    lr_clf_finbert = pickle.load(clf_file)

# Sentence encoder whose embeddings are scored by the classifier above.
model_read = SentenceTransformer('ProsusAI/finbert')
8
+
9
def get_readability(text):
    """Return the classifier's readability probability for ``text``.

    The text is embedded with the module-level ``model_read`` encoder and the
    embedding is scored with the module-level ``lr_clf_finbert`` classifier.

    Args:
        text: The passage to score.

    Returns:
        The ``predict_proba`` value at row 0, column 1, rounded to four
        decimal places.
    """
    # encode() is given a batch, so the single text is wrapped in a list.
    embedding = model_read.encode([text])
    # Column 1 is presumably the "readable" class (label 1) -- confirm against
    # the labels the classifier was trained with.
    return round(lr_clf_finbert.predict_proba(embedding)[0, 1], 4)
16
+
17
# Paraphrasing model; usage follows its model card:
# https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base
_PARAPHRASER_CHECKPOINT = "humarin/chatgpt_paraphraser_on_T5_base"

# Tokenizer and seq2seq model are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASER_CHECKPOINT)
20
+
21
def paraphrase(
    question,
    num_beams=5,
    num_beam_groups=5,
    num_return_sequences=5,
    repetition_penalty=10.0,
    diversity_penalty=3.0,
    no_repeat_ngram_size=2,
    temperature=0.7,
    max_length=128
):
    """Generate paraphrases of ``question`` with the module-level model.

    Args:
        question: The text to paraphrase; it is prefixed with
            ``"paraphrase: "`` before tokenization.
        num_beams: Forwarded to ``model.generate``.
        num_beam_groups: Forwarded to ``model.generate``.
        num_return_sequences: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        diversity_penalty: Forwarded to ``model.generate``.
        no_repeat_ngram_size: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        max_length: Truncation limit for the tokenized prompt and the
            ``max_length`` passed to ``model.generate``.

    Returns:
        The decoded generated sequences, with special tokens skipped.
    """
    prompt = f'paraphrase: {question}'
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
        truncation=True,
    )

    generation_options = {
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "max_length": max_length,
        "diversity_penalty": diversity_penalty,
    }
    generated = model.generate(encoded.input_ids, **generation_options)

    return tokenizer.batch_decode(generated, skip_special_tokens=True)
49
+
50
def get_most_raedable_paraphrse(text):
    """Return a message with the most readable rewording of ``text``.

    The candidates are the outputs of ``paraphrase(text)`` followed by
    ``text`` itself. Each candidate is scored with ``get_readability`` and the
    highest-scoring one is selected; on a tie the earliest candidate wins.

    Args:
        text: The passage to simplify.

    Returns:
        A message containing the best candidate when it differs from ``text``
        and its score is above 0.6; otherwise a message saying the input is
        already the most readable version.
    """
    candidates = paraphrase(text)
    # The input competes as well, so a paraphrase is only suggested when it
    # scores strictly higher than every earlier candidate and the input.
    candidates.append(text)

    scored = [(get_readability(candidate), candidate) for candidate in candidates]
    # max() returns the first maximal element, matching a strict ">" scan.
    score_max, best = max(scored, key=lambda pair: pair[0])

    if best != text and score_max > .6:
        return "The most redable version of text that I can think of is:\n" + best

    # This branch used to evaluate the string without assigning it, so the
    # function raised UnboundLocalError instead of returning the message.
    return "Sorry! I am not confident. As per my best knowledge, you already have the most readable version of the text!"
66
+
67
def set_example_text(example_text):
    """Build the update that copies a clicked example into a textbox.

    Args:
        example_text: The clicked dataset sample; its first element is used
            as the new textbox value.

    Returns:
        A Gradio update that sets the output component's value.
    """
    # gr.update is used instead of gr.Textbox.update: the per-component
    # update() class methods were removed in Gradio 4.
    return gr.update(value=example_text[0])
69
+
70
# Page header shown at the top of the app.
_HEADER_MARKDOWN = """
# FinLanSer
Financial Language Simplifier
"""

# Each sample is a one-element row matching the single input textbox.
_EXAMPLE_SAMPLES = [
    ['Inflation is the rate of increase in prices over a given period of time. Inflation is typically a broad measure, such as the overall increase in prices or the increase in the cost of living in a country.'],
    ['Legally assured line of credit with a bank'],
    ['A mutual fund is a type of financial vehicle made up of a pool of money collected from many investors to invest in securities like stocks, bonds, money market instruments'],
]

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    # Input box, trigger button and result box, created in display order.
    text = gr.Textbox(label="Enter text you want to simply (make more readable)")
    greet_btn = gr.Button("Simplify/Make Readable")
    output = gr.Textbox(label="Output Box")
    greet_btn.click(
        fn=get_most_raedable_paraphrse,
        inputs=text,
        outputs=output,
        api_name="get_most_raedable_paraphrse",
    )

    # Clicking an example copies it into the input textbox.
    example_text = gr.Dataset(components=[text], samples=_EXAMPLE_SAMPLES)
    example_text.click(
        fn=set_example_text,
        inputs=example_text,
        outputs=example_text.components,
    )

demo.launch()