samariddin committed on
Commit
593b58b
1 Parent(s): 1ea6a90
Files changed (2) hide show
  1. app.py +29 -12
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,22 +1,39 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import BertTokenizerFast, EncoderDecoderModel
4
 
5
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
- ckpt = "csebuetnlp/mT5_multilingual_XLSum"
7
- tokenizer = BertTokenizerFast.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
8
- model = EncoderDecoderModel.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
 
9
 
10
  def generate_summary(text):
11
 
12
- inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
13
- input_ids = inputs.input_ids.to(device)
14
- attention_mask = inputs.attention_mask.to(device)
15
- output = model.generate(input_ids, attention_mask=attention_mask)
16
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  demo = gr.Interface(fn=generate_summary,
19
- inputs=gr.Textbox(lines=10, placeholder="Insert the text here"),
20
  outputs=gr.Textbox(lines=4)
21
  )
22
 
 
1
  import gradio as gr
2
+ import re
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
def WHITESPACE_HANDLER(k):
    """Return *k* stripped, with every run of whitespace collapsed to one space.

    The pattern is a raw string: a plain '\\s' inside a normal string literal
    is an invalid escape sequence. A single ``\\s+`` substitution already
    covers newlines, so no separate newline pass is needed.
    """
    return re.sub(r"\s+", " ", k.strip())
6
+
7
+ model_name = "csebuetnlp/mT5_multilingual_XLSum"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
  def generate_summary(text):
12
 
13
+ input_ids = tokenizer(
14
+ [WHITESPACE_HANDLER(text)],
15
+ return_tensors="pt",
16
+ padding="max_length",
17
+ truncation=True,
18
+ max_length=512)["input_ids"]
19
+
20
+ output_ids = model.generate(
21
+ input_ids=input_ids,
22
+ max_length=84,
23
+ no_repeat_ngram_size=2,
24
+ num_beams=4
25
+ )[0]
26
+
27
+ summary = tokenizer.decode(
28
+ output_ids,
29
+ skip_special_tokens=True,
30
+ clean_up_tokenization_spaces=False
31
+ )
32
+
33
+ return summary
34
 
35
  demo = gr.Interface(fn=generate_summary,
36
+ inputs=gr.Textbox(lines=10, placeholder="Matinni kiriting!"),
37
  outputs=gr.Textbox(lines=4)
38
  )
39
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  gradio
2
- torch
3
  transformers
 
1
  gradio
2
+ torch
3
  transformers