Vaibhav Srivastav commited on
Commit
03bcbd5
1 Parent(s): 4b68a20

initial commit v2

Browse files
Files changed (1) hide show
  1. app.py +36 -3
app.py CHANGED
@@ -1,7 +1,40 @@
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
+ from cProfile import label
2
  import gradio as gr
3
+ import transformers
4
 
5
+ tokenizer = AutoTokenizer.from_pretrained("wanyu/IteraTeR-PEGASUS-Revision-Generator")
6
+ model = AutoModelForSeq2SeqLM.from_pretrained("wanyu/IteraTeR-PEGASUS-Revision-Generator")
7
 
8
+ def prep_input(text):
9
+ text = text.strip()
10
+ clarity_input = "<clarity> " + text
11
+ fluency_input = "<fluency> " + text
12
+ coherence_input = "<coherence> " + text
13
+ style_input = "<style> " + text
14
+ return [clarity_input, fluency_input, coherence_input, style_input]
15
+
16
+ def get_model_output(text):
17
+ model_input = tokenizer(text, return_tensors='pt')
18
+ model_outputs = model.generate(**model_input, num_beams=8, max_length=1024)
19
+ pred = tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0]
20
+ return pred
21
+
22
+ def return_predictions(text):
23
+ all_predictions = []
24
+ prepped_input = prep_input(text)
25
+ for input in prepped_input:
26
+ all_predictions.append(get_model_output(input))
27
+ return all_predictions[0], all_predictions[1], all_predictions[2], all_predictions[3]
28
+
29
+ iface = gr.Interface(fn=return_predictions,
30
+ inputs=gr.inputs.Textbox(label="Sentence/ Paragraph"),
31
+ outputs = [gr.outputs.Textbox(label="Clarity"),
32
+ gr.outputs.Textbox(label="Fluency"),
33
+ gr.outputs.Textbox(label="Coherence"),
34
+ gr.outputs.Textbox(label="Style")],
35
+ title="ITERATER: Understanding Iterative Revision from Human-Written text",
36
+ description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
37
+ layout = "horizontal",
38
+ theme="huggingface",
39
+ enable_queue=True)
40
  iface.launch()