Someman committed on
Commit
b148f92
1 Parent(s): d8a3c34

updated app.py #1

Browse files
Files changed (1) hide show
  1. app.py +35 -4
app.py CHANGED
@@ -1,12 +1,43 @@
1
  import gradio as gr
 
2
 
3
- title = "Nepali News Summarization"
4
 
 
5
 
6
- demo = gr.load(
7
- "GenzNepal/mt5-summarize-nepalil",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  inputs=gr.Textbox(lines=5, max_lines=20, label="Input Text"),
9
- title=title,
10
  )
11
 
12
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import torch
3
 
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # Predict with test data (first 5 rows)
7
 
8
+ model_ckpt = "GenzNepal/mt5-summarize-nepali"
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ t5_tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
13
+
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)
15
+
16
+
17
+
18
+ def summarize(text):
19
+ inputs = t5_tokenizer(text, return_tensors="pt", max_length=1024, padding= "max_length", truncation=True, add_special_tokens=True)
20
+ generation = model.generate(
21
+ input_ids = inputs['input_ids'].to(device),
22
+ attention_mask=inputs['attention_mask'].to(device),
23
+ num_beams=6,
24
+ num_return_sequences=1,
25
+ no_repeat_ngram_size=2,
26
+ repetition_penalty=1.0,
27
+ min_length=100,
28
+ max_length=250,
29
+ length_penalty=2.0,
30
+ early_stopping=True
31
+ )
32
+ # # Convert id tokens to text
33
+ output = t5_tokenizer.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
34
+ return output
35
+
36
+
37
+ demo = gr.interface(
38
+ summarize,
39
  inputs=gr.Textbox(lines=5, max_lines=20, label="Input Text"),
40
+ outputs="Summarization"
41
  )
42
 
43
  if __name__ == "__main__":