tomato committed on
Commit
7d93f13
1 Parent(s): 4db7dc4

Switch from using [pipeline] to using [Model Instruction]

Browse files
Files changed (1) hide show
  1. app.py +30 -3
app.py CHANGED
@@ -1,14 +1,41 @@
1
  import gradio as gr
2
  import torch
3
  from tqdm import tqdm
4
- from transformers import pipeline
 
5
 
6
  MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
7
 
8
- summarizer = pipeline(task="summarization", model=MODEL_NAME)
 
 
 
9
 
10
  def summarize(text):
11
- return summarizer(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  demo = gr.Blocks(title="⭐ Summ4rizer ⭐")
14
  demo.encrypt = False
 
1
  import gradio as gr
2
  import torch
3
  from tqdm import tqdm
4
+ import re
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
 
7
  MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
8
 
9
+ WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
13
 
14
  def summarize(text):
15
+
16
+ input_ids = tokenizer(
17
+ [WHITESPACE_HANDLER(text)],
18
+ return_tensors="pt",
19
+ padding="max_length",
20
+ truncation=True,
21
+ max_length=512
22
+ )["input_ids"]
23
+
24
+ output_ids = model.generate(
25
+ input_ids=input_ids,
26
+ max_length=84,
27
+ no_repeat_ngram_size=2,
28
+ num_beams=4
29
+ )[0]
30
+
31
+ summary = tokenizer.decode(
32
+ output_ids,
33
+ skip_special_tokens=True,
34
+ clean_up_tokenization_spaces=False
35
+ )
36
+ return summary
37
+
38
+
39
 
40
  demo = gr.Blocks(title="⭐ Summ4rizer ⭐")
41
  demo.encrypt = False